From 079a0343c363eac5d4959d18aa60911fe1760b61 Mon Sep 17 00:00:00 2001 From: jdidion Date: Mon, 3 Feb 2025 16:27:34 -0800 Subject: [PATCH 01/67] implement retry queues as per-thread --- src/hive/channel.rs | 967 +++++++++++++++++++++++++++++++++++++++++ src/hive/task/delay.rs | 165 +++++++ src/hive/task/iter.rs | 27 ++ src/hive/task/local.rs | 44 ++ src/hive/task/mod.rs | 41 ++ 5 files changed, 1244 insertions(+) create mode 100644 src/hive/channel.rs create mode 100644 src/hive/task/delay.rs create mode 100644 src/hive/task/iter.rs create mode 100644 src/hive/task/local.rs create mode 100644 src/hive/task/mod.rs diff --git a/src/hive/channel.rs b/src/hive/channel.rs new file mode 100644 index 0000000..05e6685 --- /dev/null +++ b/src/hive/channel.rs @@ -0,0 +1,967 @@ +use super::prelude::*; +use super::{ + Config, DerefOutcomes, HiveInner, LocalQueues, OutcomeSender, Shared, SpawnError, TaskSender, +}; +use crate::atomic::Atomic; +use crate::bee::{DefaultQueen, Queen, TaskId, Worker}; +use crossbeam_utils::Backoff; +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::{Deref, DerefMut}; +use std::sync::{mpsc, Arc}; +use std::thread::{self, JoinHandle}; + +#[derive(thiserror::Error, Debug)] +#[error("The hive has been poisoned")] +pub struct Poisoned; + +impl> Hive { + /// Spawns a new worker thread with the specified index and with access to the `shared` data. + fn try_spawn>( + thread_index: usize, + shared: Arc>, + ) -> Result, SpawnError> { + // spawn a thread that executes the worker loop + shared.thread_builder().spawn(move || { + // perform one-time initialization of the worker thread + Self::init_thread(thread_index, &shared); + // create a Sentinel that will spawn a new thread on panic until it is cancelled + let sentinel = Sentinel::new(thread_index, Arc::clone(&shared)); + // create a new Worker instance + let mut worker = shared.create_worker(); + // execute the main loop + // get the next task to process - this decrements the queued counter and increments + // the active counter + while let Ok(task) = shared.next_task(thread_index) { + // execute the task until it succeeds or we reach maximum retries - this should + // be the only place where a panic can occur + Self::execute(task, thread_index, &mut worker, &shared); + // finish the task - decrements the active counter and notifies other threads + shared.finish_task(false); + } + // this is only reachable when the main loop exits due to the task receiver having + // disconnected; cancel the Sentinel so this thread won't be re-spawned on drop + sentinel.cancel(); + }) + } + + /// Creates a new `Hive`. This should only be called from `Builder`. + /// + /// The `Hive` will attempt to spawn the configured number of worker threads + /// (`config.num_threads`) but the actual number of threads available may be lower if there + /// are any errors during spawning. + pub(super) fn new(config: Config, queen: Q) -> Self { + let (task_tx, task_rx) = mpsc::channel(); + let shared = Arc::new(Shared::new(config.into_sync(), queen, task_rx)); + shared.init_threads(|thread_index| Self::try_spawn(thread_index, Arc::clone(&shared))); + Self(Some(HiveInner { task_tx, shared })) + } + + #[inline] + fn task_tx(&self) -> &TaskSender { + &self.0.as_ref().unwrap().task_tx + } + + /// Attempts to increase the number of worker threads by `num_threads`. Returns the number of + /// new worker threads that were successfully started (which may be fewer than `num_threads`), + /// or a `Poisoned` error if the hive has been poisoned. 
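+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch of growing the pool at runtime, using the stock `ThunkWorker`:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::ThunkWorker;
+    /// use beekeeper::hive::Builder;
+    ///
+    /// let hive = Builder::new()
+    ///     .num_threads(2)
+    ///     .build_with_default::<ThunkWorker<()>>();
+    /// // request two more worker threads; fewer may actually start if spawning fails
+    /// let started = hive.grow(2).expect("hive should not be poisoned");
+    /// assert!(started <= 2);
+    /// ```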
+ pub fn grow(&self, num_threads: usize) -> Result { + if num_threads == 0 { + return Ok(0); + } + let shared = &self.0.as_ref().unwrap().shared; + // do not start any new threads if the hive is poisoned + if shared.is_poisoned() { + return Err(Poisoned); + } + let num_started = shared.grow_threads(num_threads, |thread_index| { + Self::try_spawn(thread_index, Arc::clone(shared)) + }); + Ok(num_started) + } + + /// Sets the number of worker threads to the number of available CPU cores. Returns the number + /// of new threads that were successfully started (which may be `0`), or a `Poisoned` error if + /// the hive has been poisoned. + pub fn use_all_cores(&self) -> Result { + let num_threads = num_cpus::get().saturating_sub(self.max_workers()); + self.grow(num_threads) + } + + /// Sends one input to the `Hive` for processing and returns its ID. The `Outcome` + /// of the task is sent to the `outcome_tx` channel if provided, otherwise it is retained in + /// the `Hive` for later retrieval. + /// + /// This method is called by all the `*apply*` methods. + fn send_one(&self, input: W::Input, outcome_tx: Option>) -> TaskId { + #[cfg(debug_assertions)] + if self.max_workers() == 0 { + dbg!("WARNING: no worker threads are active for hive"); + } + let shared = &self.0.as_ref().unwrap().shared; + let task = shared.prepare_task(input, outcome_tx); + let task_id = task.id(); + // try to send the task to the hive; if the hive is poisoned or if sending fails, convert + // the task into an `Unprocessed` outcome and try to send it to the outcome channel; if + // that fails, store the outcome in the hive + if let Some(abandoned_task) = if self.is_poisoned() { + Some(task) + } else { + self.task_tx().send(task).err().map(|err| err.0) + } { + shared.abandon_task(abandoned_task); + } + task_id + } + + /// Sends one `input` to the `Hive` for procesing and returns the result, blocking until the + /// result is available. Creates a channel to send the input and receive the outcome. Returns + /// an [`Outcome`] with the task output or an error. + pub fn apply(&self, input: W::Input) -> Outcome { + let (tx, rx) = outcome_channel(); + let task_id = self.send_one(input, Some(tx)); + rx.recv().unwrap_or_else(|_| Outcome::Missing { task_id }) + } + + /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of + /// the task will be sent to `tx` upon completion. + pub fn apply_send(&self, input: W::Input, tx: OutcomeSender) -> TaskId { + self.send_one(input, Some(tx)) + } + + /// Sends one `input` to the `Hive` for processing and returns its ID immediately. The + /// [`Outcome`] of the task will be retained and available for later retrieval. + pub fn apply_store(&self, input: W::Input) -> TaskId { + self.send_one(input, None) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns a `Vec` of their + /// task IDs. The [`Outcome`]s of the tasks are sent to the `outcome_tx` channel if provided, + /// otherwise they are retained in the `Hive` for later retrieval. + /// + /// The batch is provided as an [`ExactSizeIterator`], which enables the hive to reserve a + /// range of task IDs (a single atomic operation) rather than one at a time. + /// + /// This method is called by all the `swarm*` methods. 
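+    ///
+    /// A sketch of this path as seen from the public API (all `swarm*` methods route
+    /// through this method):
+    ///
+    /// ```ignore
+    /// // `0..10` is an `ExactSizeIterator`, so all ten task IDs are reserved with a
+    /// // single atomic operation before any task is queued
+    /// let task_ids = hive.swarm_store(0..10);
+    /// assert_eq!(task_ids.len(), 10);
+    /// ```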
+ fn send_batch(&self, batch: T, outcome_tx: Option>) -> Vec + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + #[cfg(debug_assertions)] + if self.max_workers() == 0 { + dbg!("WARNING: no worker threads are active for hive"); + } + let task_tx = self.task_tx(); + let iter = batch.into_iter(); + let (batch_size, _) = iter.size_hint(); + let shared = &self.0.as_ref().unwrap().shared; + let batch = shared.prepare_batch(batch_size, iter, outcome_tx); + if !self.is_poisoned() { + batch + .map(|task| { + let task_id = task.id(); + // try to send the task to the hive; if sending fails, convert the task into an + // `Unprocessed` outcome and try to send it to the outcome channel; if that + // fails, store the outcome in the hive + if let Err(err) = task_tx.send(task) { + shared.abandon_task(err.0); + } + task_id + }) + .collect() + } else { + // if the hive is poisoned, convert all tasks into `Unprocessed` outcomes and try to + // send them to their outcome channels or store them in the hive + (&self.0.as_ref().unwrap().shared).abandon_batch(batch) + } + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns an iterator over the + /// [`Outcome`]s in the same order as the inputs. + /// + /// This method is more efficient than [`map`](Self::map) when the input is an + /// [`ExactSizeIterator`]. + pub fn swarm(&self, batch: T) -> impl Iterator> + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + let (tx, rx) = outcome_channel(); + let task_ids = self.send_batch(batch, Some(tx)); + rx.select_ordered(task_ids) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns an unordered iterator + /// over the [`Outcome`]s. + /// + /// The `Outcome`s will be sent in the order they are completed; use [`swarm`](Self::swarm) to + /// instead receive the `Outcome`s in the order they were submitted. This method is more + /// efficient than [`map_unordered`](Self::map_unordered) when the input is an + /// [`ExactSizeIterator`]. + pub fn swarm_unordered(&self, batch: T) -> impl Iterator> + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + let (tx, rx) = outcome_channel(); + let task_ids = self.send_batch(batch, Some(tx)); + rx.select_unordered(task_ids) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. + /// The [`Outcome`]s of the tasks will be sent to `tx` upon completion. + /// + /// This method is more efficient than [`map_send`](Self::map_send) when the input is an + /// [`ExactSizeIterator`]. + pub fn swarm_send(&self, batch: T, outcome_tx: OutcomeSender) -> Vec + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + self.send_batch(batch, Some(outcome_tx)) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. + /// The [`Outcome`]s of the task are retained and available for later retrieval. + /// + /// This method is more efficient than `map_store` when the input is an [`ExactSizeIterator`]. + pub fn swarm_store(&self, batch: T) -> Vec + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + self.send_batch(batch, None) + } + + /// Iterates over `inputs` and sends each one to the `Hive` for processing and returns an + /// iterator over the [`Outcome`]s in the same order as the inputs. + /// + /// [`swarm`](Self::swarm) should be preferred when `inputs` is an [`ExactSizeIterator`]. 
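+    ///
+    /// # Examples
+    ///
+    /// A minimal sketch using the stock `Caller` worker; outputs are yielded in input
+    /// order:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::Caller;
+    /// use beekeeper::hive::{Builder, OutcomeIteratorExt};
+    ///
+    /// let hive = Builder::new()
+    ///     .num_threads(4)
+    ///     .build_with(Caller::of(|i: usize| i * 2));
+    /// let doubled: Vec<usize> = hive.map(0..10).into_outputs().collect();
+    /// assert_eq!(doubled, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
+    /// ```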
+ pub fn map( + &self, + inputs: impl IntoIterator, + ) -> impl Iterator> { + let (tx, rx) = outcome_channel(); + let task_ids: Vec<_> = inputs + .into_iter() + .map(|task| self.apply_send(task, tx.clone())) + .collect(); + rx.select_ordered(task_ids) + } + + /// Iterates over `inputs`, sends each one to the `Hive` for processing, and returns an + /// iterator over the [`Outcome`]s in order they become available. + /// + /// [`swarm_unordered`](Self::swarm_unordered) should be preferred when `inputs` is an + /// [`ExactSizeIterator`]. + pub fn map_unordered( + &self, + inputs: impl IntoIterator, + ) -> impl Iterator> { + let (tx, rx) = outcome_channel(); + // `map` is required (rather than `inspect`) because we need owned items + let task_ids: Vec<_> = inputs + .into_iter() + .map(|task| self.apply_send(task, tx.clone())) + .collect(); + rx.select_unordered(task_ids) + } + + /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] + /// of task IDs. The [`Outcome`]s of the tasks will be sent to `tx` upon completion. + /// + /// [`swarm_send`](Self::swarm_send) should be preferred when `inputs` is an + /// [`ExactSizeIterator`]. + pub fn map_send( + &self, + inputs: impl IntoIterator, + tx: OutcomeSender, + ) -> Vec { + inputs + .into_iter() + .map(|input| self.apply_send(input, tx.clone())) + .collect() + } + + /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] + /// of task IDs. The [`Outcome`]s of the task are retained and available for later retrieval. + /// + /// [`swarm_store`](Self::swarm_store) should be preferred when `inputs` is an + /// [`ExactSizeIterator`]. + pub fn map_store(&self, inputs: impl IntoIterator) -> Vec { + inputs + .into_iter() + .map(|input| self.apply_store(input)) + .collect() + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing. + /// Returns an [`OutcomeBatch`] of the outputs and the final state value. + pub fn scan( + &self, + items: impl IntoIterator, + init: St, + f: F, + ) -> (OutcomeBatch, St) + where + F: FnMut(&mut St, T) -> W::Input, + { + let (tx, rx) = outcome_channel(); + let (task_ids, fold_value) = self.scan_send(items, tx, init, f); + let outcomes = rx.select_unordered(task_ids).into(); + (outcomes, fold_value) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing, + /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final + /// state value. + pub fn try_scan( + &self, + items: impl IntoIterator, + init: St, + mut f: F, + ) -> (OutcomeBatch, Vec, St) + where + F: FnMut(&mut St, T) -> Result, + { + let (tx, rx) = outcome_channel(); + let (task_ids, errors, fold_value) = items.into_iter().fold( + (Vec::new(), Vec::new(), init), + |(mut task_ids, mut errors, mut acc), inp| { + match f(&mut acc, inp) { + Ok(input) => task_ids.push(self.apply_send(input, tx.clone())), + Err(err) => errors.push(err), + } + (task_ids, errors, acc) + }, + ); + let outcomes = rx.select_unordered(task_ids).into(); + (outcomes, errors, fold_value) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. 
+ /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the + /// task IDs and the final state value. + pub fn scan_send( + &self, + items: impl IntoIterator, + tx: OutcomeSender, + init: St, + mut f: F, + ) -> (Vec, St) + where + F: FnMut(&mut St, T) -> W::Input, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { + let input = f(&mut acc, item); + task_ids.push(self.apply_send(input, tx.clone())); + (task_ids, acc) + }) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, + /// or an error. The outputs are sent to `tx` in the order they become available. This + /// function returns the final state value and a [`Vec`] of results, where each result is + /// either a task ID or an error. + pub fn try_scan_send( + &self, + items: impl IntoIterator, + tx: OutcomeSender, + init: St, + mut f: F, + ) -> (Vec>, St) + where + F: FnMut(&mut St, T) -> Result, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut results, mut acc), inp| { + results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx.clone()))); + (results, acc) + }) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. + /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of + /// the tasks are retained and available for later retrieval. + pub fn scan_store( + &self, + items: impl IntoIterator, + init: St, + mut f: F, + ) -> (Vec, St) + where + F: FnMut(&mut St, T) -> W::Input, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { + let input = f(&mut acc, item); + task_ids.push(self.apply_store(input)); + (task_ids, acc) + }) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, + /// or an error. This function returns the final value of the state value and a [`Vec`] of + /// results, where each result is either a task ID or an error. The [`Outcome`]s of the + /// tasks are retained and available for later retrieval. + pub fn try_scan_store( + &self, + items: impl IntoIterator, + init: St, + mut f: F, + ) -> (Vec>, St) + where + F: FnMut(&mut St, T) -> Result, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut results, mut acc), item| { + results.push(f(&mut acc, item).map(|input| self.apply_store(input))); + (results, acc) + }) + } + + /// Blocks the calling thread until all tasks finish. + pub fn join(&self) { + (&self.0.as_ref().unwrap().shared).wait_on_done(); + } + + /// Returns the [`MutexGuard`](parking_lot::MutexGuard) for the [`Queen`]. + /// + /// Note that the `Queen` will remain locked until the returned guard is dropped, and that + /// locking the `Queen` prevents new worker threads from being started. + pub fn queen(&self) -> impl Deref + '_ { + (&self.0.as_ref().unwrap().shared).queen.lock() + } + + /// Returns the number of worker threads that have been requested, i.e., the maximum number of + /// tasks that could be processed concurrently. This may be greater than + /// [`active_workers`](Self::active_workers) if any of the worker threads failed to start. 
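+    ///
+    /// # Examples
+    ///
+    /// A sketch assuming all requested threads spawn successfully:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::ThunkWorker;
+    /// use beekeeper::hive::Builder;
+    ///
+    /// let hive = Builder::new()
+    ///     .num_threads(4)
+    ///     .build_with_default::<ThunkWorker<()>>();
+    /// assert_eq!(hive.max_workers(), 4);
+    /// assert_eq!(hive.alive_workers(), 4);
+    /// ```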
+ pub fn max_workers(&self) -> usize { + (&self.0.as_ref().unwrap().shared) + .config + .num_threads + .get_or_default() + } + + /// Returns the number of worker threads that have been successfully started. This may be + /// fewer than [`max_workers`](Self::max_workers) if any of the worker threads failed to start. + pub fn alive_workers(&self) -> usize { + (&self.0.as_ref().unwrap().shared) + .spawn_results + .lock() + .iter() + .filter(|result| result.is_ok()) + .count() + } + + /// Returns `true` if there are any "dead" worker threads that failed to spawn. + pub fn has_dead_workers(&self) -> bool { + (&self.0.as_ref().unwrap().shared) + .spawn_results + .lock() + .iter() + .any(|result| result.is_err()) + } + + /// Attempts to respawn any dead worker threads. Returns the number of worker threads that were + /// successfully respawned. + pub fn revive_workers(&self) -> usize { + let shared = &self.0.as_ref().unwrap().shared; + shared + .respawn_dead_threads(|thread_index| Self::try_spawn(thread_index, Arc::clone(shared))) + } + + /// Returns the number of tasks currently (queued for processing, being processed). + pub fn num_tasks(&self) -> (u64, u64) { + (&self.0.as_ref().unwrap().shared).num_tasks() + } + + /// Returns the number of times one of this `Hive`'s worker threads has panicked. + pub fn num_panics(&self) -> usize { + (&self.0.as_ref().unwrap().shared).num_panics.get() + } + + /// Returns `true` if this `Hive` has been poisoned - i.e., its internal state has been + /// corrupted such that it is no longer able to process tasks. + /// + /// Note that, when a `Hive` is poisoned, it is still possible to call methods that extract + /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., + /// [`try_into_husk`](Self::try_into_husk)). + pub fn is_poisoned(&self) -> bool { + (&self.0.as_ref().unwrap().shared).is_poisoned() + } + + /// Returns `true` if the suspended flag is set. + pub fn is_suspended(&self) -> bool { + (&self.0.as_ref().unwrap().shared).is_suspended() + } + + /// Sets the suspended flag, which notifies worker threads that they a) MAY terminate their + /// current task early (returning an [`Outcome::Unprocessed`]), and b) MUST not accept new + /// tasks, and instead block until the suspended flag is cleared. + /// + /// Call [`resume`](Self::resume) to unset the suspended flag and continue processing tasks. + /// + /// Note: this does *not* prevent new tasks from being queued, and there is a window of time + /// (~1 second) after the suspended flag is set within which a worker thread may still accept a + /// new task. + /// + /// # Examples + /// + /// ``` + /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; + /// use beekeeper::hive::Builder; + /// use std::thread; + /// use std::time::Duration; + /// + /// # fn main() { + /// let hive = Builder::new() + /// .num_threads(4) + /// .build_with_default::>(); + /// hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); + /// thread::sleep(Duration::from_secs(1)); // Allow first set of tasks to be started. + /// // There should be 4 active tasks and 6 queued tasks. + /// hive.suspend(); + /// assert_eq!(hive.num_tasks(), (6, 4)); + /// // Wait for active tasks to complete. + /// hive.join(); + /// assert_eq!(hive.num_tasks(), (6, 0)); + /// hive.resume(); + /// // Wait for remaining tasks to complete. 
+ /// hive.join(); + /// assert_eq!(hive.num_tasks(), (0, 0)); + /// # } + /// ``` + pub fn suspend(&self) { + (&self.0.as_ref().unwrap().shared).set_suspended(true); + } + + /// Unsets the suspended flag, allowing worker threads to continue processing queued tasks. + pub fn resume(&self) { + (&self.0.as_ref().unwrap().shared).set_suspended(false); + } + + /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over + /// the input values. + fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator { + (&self.0.as_ref().unwrap().shared) + .take_unprocessed() + .into_iter() + .map(|outcome| match outcome { + Outcome::Unprocessed { input, task_id: _ } => input, + _ => unreachable!(), + }) + } + + /// If this `Hive` is suspended, resumes this `Hive` and re-submits any unprocessed tasks for + /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that + /// were resumed. + pub fn resume_send(&self, outcome_tx: OutcomeSender) -> Vec { + (&self.0.as_ref().unwrap().shared) + .set_suspended(false) + .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)) + .unwrap_or_default() + } + + /// If this `Hive` is suspended, resumes this `Hive` and re-submit any unprocessed tasks for + /// processing, with their results to be stored in the queue. Returns a [`Vec`] of task IDs + /// that were resumed. + pub fn resume_store(&self) -> Vec { + (&self.0.as_ref().unwrap().shared) + .set_suspended(false) + .then(|| self.swarm_store(self.take_unprocessed_inputs())) + .unwrap_or_default() + } + + /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s. + pub fn take_stored(&self) -> HashMap> { + (&self.0.as_ref().unwrap().shared).take_outcomes() + } + + /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this + /// `Hive`, including any stored task outcomes, and all the data necessary to create a new + /// `Hive`. + /// + /// If this `Hive` has been cloned, and those clones have not been dropped, this method + /// returns `None` since it cannot take exclusive ownership of the internal shared data. + /// + /// This method first joins on the `Hive` to wait for all tasks to finish. + pub fn try_into_husk(mut self) -> Option> { + if (&self.0.as_ref().unwrap().shared).num_referrers() > 1 { + return None; + } + // take the inner value and replace it with `None` + let inner = self.0.take().unwrap(); + // wait for all tasks to finish + inner.shared.wait_on_done(); + // drop the task sender so receivers will drop automatically + drop(inner.task_tx); + // wait for worker threads to drop, then take ownership of the shared data and convert it + // into a Husk + let mut shared = inner.shared; + let mut backoff = None::; + loop { + // TODO: may want to have some timeout or other kind of limit to prevent this from + // looping forever if a worker thread somehow gets stuck, or if the `num_referrers` + // counter is corrupted + shared = match Arc::try_unwrap(shared) { + Ok(shared) => { + return Some(shared.try_into_husk()); + } + Err(shared) => { + backoff.get_or_insert_with(Backoff::new).spin(); + shared + } + }; + } + } +} + +impl Default for Hive> { + fn default() -> Self { + Builder::default().build_with_default::() + } +} + +impl> Clone for Hive { + /// Creates a shallow copy of this `Hive` containing references to its same internal state, + /// i.e., all clones of a `Hive` submit tasks to the same shared worker thread pool. 
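+    ///
+    /// A minimal sketch: a task submitted through a clone is stored in, and can be
+    /// retrieved from, the same hive:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+    /// use beekeeper::hive::Builder;
+    ///
+    /// let hive = Builder::new()
+    ///     .num_threads(2)
+    ///     .build_with_default::<ThunkWorker<()>>();
+    /// let hive2 = hive.clone();
+    /// let task_id = hive2.apply_store(Thunk::of(|| ()));
+    /// hive.join();
+    /// // the outcome is visible through the original handle
+    /// assert!(hive.take_stored().contains_key(&task_id));
+    /// ```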
+ fn clone(&self) -> Self { + let inner = self.0.as_ref().unwrap(); + (&inner.shared).referrer_is_cloning(); + Self(Some(inner.clone())) + } +} + +impl> Clone for HiveInner { + fn clone(&self) -> Self { + HiveInner { + task_tx: self.task_tx.clone(), + shared: Arc::clone(&self.shared), + } + } +} + +impl> Debug for Hive { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(inner) = self.0.as_ref() { + f.debug_struct("Hive") + .field("task_tx", &inner.task_tx) + .field("shared", &inner.shared) + .finish() + } else { + f.write_str("Hive {}") + } + } +} + +impl> PartialEq for Hive { + fn eq(&self, other: &Hive) -> bool { + let self_shared = &self.0.as_ref().unwrap().shared; + let other_shared = &other.0.as_ref().unwrap().shared; + Arc::ptr_eq(self_shared, other_shared) + } +} + +impl> Eq for Hive {} + +impl> DerefOutcomes for Hive { + #[inline] + fn outcomes_deref(&self) -> impl Deref>> { + (&self.0.as_ref().unwrap().shared).outcomes() + } + + #[inline] + fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { + (&self.0.as_ref().unwrap().shared).outcomes() + } +} + +impl> Drop for Hive { + fn drop(&mut self) { + // if this Hive has already been turned into a Husk, it's inner value will be `None` + if let Some(inner) = self.0.as_ref() { + // reduce the referrer count + let _ = inner.shared.referrer_is_dropping(); + // if this Hive is the only one with a pointer to the shared data, poison it + // to prevent any worker threads that still have access to the shared data from + // re-spawning. + if inner.shared.num_referrers() == 0 { + inner.shared.poison(); + } + } + } +} + +/// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker +/// thread if it panics. +struct Sentinel, L: LocalQueues> { + thread_index: usize, + shared: Arc>, + active: bool, +} + +impl, L: LocalQueues> Sentinel { + fn new(thread_index: usize, shared: Arc>) -> Self { + Self { + thread_index, + shared, + active: true, + } + } + + /// Cancel and destroy this sentinel. + fn cancel(mut self) { + self.active = false; + } +} + +impl, L: LocalQueues> Drop for Sentinel { + fn drop(&mut self) { + if self.active { + // if the sentinel is active, that means the thread panicked during task execution, so + // we have to finish the task here before respawning + self.shared.finish_task(thread::panicking()); + // only respawn if the sentinel is active and the hive has not been poisoned + if !self.shared.is_poisoned() { + // can't do anything with the previous result + let _ = self + .shared + .respawn_thread(self.thread_index, |thread_index| { + Hive::try_spawn(thread_index, Arc::clone(&self.shared)) + }); + } + } + } +} + +#[cfg(not(feature = "affinity"))] +mod no_affinity { + use crate::bee::{Queen, Worker}; + use crate::hive::{Hive, LocalQueues, Shared}; + + impl> Hive { + #[inline] + pub(super) fn init_thread>(_: usize, _: &Shared) {} + } +} + +#[cfg(feature = "affinity")] +mod affinity { + use crate::bee::{Queen, Worker}; + use crate::hive::cores::Cores; + use crate::hive::{Hive, Poisoned, Shared}; + + impl> Hive { + /// Tries to pin the worker thread to a specific CPU core. + #[inline] + pub(super) fn init_thread(thread_index: usize, shared: &Shared) { + if let Some(core) = shared.get_core_affinity(thread_index) { + core.try_pin_current(); + } + } + + /// Attempts to increase the number of worker threads by `num_threads`. 
+ /// + /// The provided `affinity` specifies additional CPU core indices to which the worker + /// threads may be pinned - these are added to the existing pool of core indices (if any). + /// + /// Returns the number of new worker threads that were successfully started (which may be + /// fewer than `num_threads`) or a `Poisoned` error if the hive has been poisoned. + pub fn grow_with_affinity>( + &self, + num_threads: usize, + affinity: C, + ) -> Result { + (&self.0.as_ref().unwrap().shared).add_core_affinity(affinity.into()); + self.grow(num_threads) + } + + /// Sets the number of worker threads to the number of available CPU cores. An attempt is + /// made to pin each worker thread to a different CPU core. + /// + /// Returns the number of new threads spun up (if any) or a `Poisoned` error if the hive + /// has been poisoned. + pub fn use_all_cores_with_affinity(&self) -> Result { + (&self.0.as_ref().unwrap().shared).add_core_affinity(Cores::all()); + self.use_all_cores() + } + } +} + +#[cfg(feature = "batching")] +mod batching { + use crate::bee::{Queen, Worker}; + use crate::hive::Hive; + + impl> Hive { + /// Returns the batch size for worker threads. + pub fn worker_batch_size(&self) -> usize { + (&self.0.as_ref().unwrap().shared).batch_size() + } + + /// Sets the batch size for worker threads. This will block the current thread until all + /// worker thread queues can be resized. + pub fn set_worker_batch_size(&self, batch_size: usize) { + (&self.0.as_ref().unwrap().shared).set_batch_size(batch_size); + } + } +} + +#[cfg(not(feature = "retry"))] +mod no_retry { + use crate::bee::{Queen, Worker}; + use crate::hive::{Hive, LocalQueue, Outcome, Shared, Task}; + + impl> Hive { + #[inline] + pub(super) fn execute>( + task: Task, + _thread_index: usize, + worker: &mut W, + shared: &Shared, + ) { + let (input, ctx, outcome_tx) = task.into_parts(); + let result = worker.apply(input, &ctx); + let outcome = Outcome::from_worker_result(result, ctx.task_id()); + shared.send_or_store_outcome(outcome, outcome_tx); + } + } +} + +#[cfg(feature = "retry")] +mod retry { + use crate::bee::{ApplyError, Queen, Worker}; + use crate::hive::{Hive, LocalQueues, Outcome, Shared, Task}; + + impl> Hive { + #[inline] + pub(super) fn execute>( + task: Task, + thread_index: usize, + worker: &mut W, + shared: &Shared, + ) { + let (input, mut ctx, outcome_tx) = task.into_parts(); + match worker.apply(input, &ctx) { + Err(ApplyError::Retryable { input, .. }) if shared.can_retry(&ctx) => { + ctx.inc_attempt(); + shared.queue_retry(thread_index, input, ctx, outcome_tx); + } + result => { + let outcome = Outcome::from_worker_result(result, ctx.task_id()); + shared.send_or_store_outcome(outcome, outcome_tx); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::Poisoned; + use crate::bee::stock::{Caller, Thunk, ThunkWorker}; + use crate::hive::{outcome_channel, Builder, Outcome, OutcomeIteratorExt}; + use std::collections::HashMap; + use std::thread; + use std::time::Duration; + + #[test] + fn test_suspend() { + let hive = Builder::new() + .num_threads(4) + .build_with_default::>(); + let outcome_iter = + hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); + // Allow first set of tasks to be started. + thread::sleep(Duration::from_secs(1)); + // There should be 4 active tasks and 6 queued tasks. + hive.suspend(); + assert_eq!(hive.num_tasks(), (6, 4)); + // Wait for active tasks to complete. 
+ hive.join(); + assert_eq!(hive.num_tasks(), (6, 0)); + hive.resume(); + // Wait for remaining tasks to complete. + hive.join(); + assert_eq!(hive.num_tasks(), (0, 0)); + let outputs: Vec<_> = outcome_iter.into_outputs().collect(); + assert_eq!(outputs.len(), 10); + } + + #[test] + fn test_spawn_after_poison() { + let hive = Builder::new() + .num_threads(4) + .build_with_default::>(); + assert_eq!(hive.max_workers(), 4); + assert_eq!(hive.alive_workers(), 4); + // poison hive using private method + hive.0.as_ref().unwrap().shared.poison(); + // attempt to spawn a new task + assert!(matches!(hive.grow(1), Err(Poisoned))); + // make sure the worker count wasn't increased + assert_eq!(hive.max_workers(), 4); + assert_eq!(hive.alive_workers(), 4); + } + + #[test] + fn test_apply_after_poison() { + let hive = Builder::new() + .num_threads(4) + .build_with(Caller::of(|i: usize| i * 2)); + // poison hive using private method + hive.0.as_ref().unwrap().shared.poison(); + // submit a task, check that it comes back unprocessed + let (tx, rx) = outcome_channel(); + let sent_input = 1; + let sent_task_id = hive.apply_send(sent_input, tx.clone()); + let outcome = rx.recv().unwrap(); + match outcome { + Outcome::Unprocessed { input, task_id } => { + assert_eq!(input, sent_input); + assert_eq!(task_id, sent_task_id); + } + _ => panic!("Expected unprocessed outcome"), + } + } + + #[test] + fn test_swarm_after_poison() { + let hive = Builder::new() + .num_threads(4) + .build_with(Caller::of(|i: usize| i * 2)); + // poison hive using private method + hive.0.as_ref().unwrap().shared.poison(); + // submit a task, check that it comes back unprocessed + let (tx, rx) = outcome_channel(); + let inputs = 0..10; + let task_ids: HashMap = hive + .swarm_send(inputs.clone(), tx) + .into_iter() + .zip(inputs) + .collect(); + for outcome in rx.into_iter().take(10) { + match outcome { + Outcome::Unprocessed { input, task_id } => { + let expected_input = task_ids.get(&task_id); + assert!(expected_input.is_some()); + assert_eq!(input, *expected_input.unwrap()); + } + _ => panic!("Expected unprocessed outcome"), + } + } + } +} diff --git a/src/hive/task/delay.rs b/src/hive/task/delay.rs new file mode 100644 index 0000000..7e13186 --- /dev/null +++ b/src/hive/task/delay.rs @@ -0,0 +1,165 @@ +use std::cell::UnsafeCell; +use std::cmp::Ordering; +use std::collections::BinaryHeap; +use std::time::{Duration, Instant}; + +/// A queue where each item has an associated `Instant` at which it will be available. +/// +/// This is implemented internally as a `UnsafeCell`. +/// +/// SAFETY: This data structure is designed to enable the queue to be modified by a *single thread* +/// using interior mutability. `UnsafeCell` is used for performance - this is safe so long as the +/// queue is only accessed from a single thread at a time. This data structure is *not* thread-safe. +#[derive(Debug)] +pub struct DelayQueue(UnsafeCell>>); + +impl DelayQueue { + /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available, + /// or an error with `item` if there was an error pushing the item. + pub fn push(&self, item: T, delay: Duration) -> Result { + unsafe { + match self.0.get().as_mut() { + Some(queue) => { + let delayed = Delayed::new(item, delay); + let until = delayed.until; + queue.push(delayed); + Ok(until) + } + None => Err(item), + } + } + } + + /// Returns the `Instant` at which the next item will be available. Returns `None` if the queue + /// is empty. 
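+    ///
+    /// A sketch of the expected behavior (this type is crate-private and single-threaded,
+    /// so the example is not compiled):
+    ///
+    /// ```ignore
+    /// let queue = DelayQueue::default();
+    /// assert!(queue.next_available().is_none());
+    /// let until = queue.push("item", Duration::from_secs(1)).unwrap();
+    /// assert_eq!(queue.next_available(), Some(until));
+    /// ```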
+    pub fn next_available(&self) -> Option<Instant> {
+        unsafe {
+            self.0
+                .get()
+                .as_ref()
+                .and_then(|queue| queue.peek().map(|head| head.until))
+        }
+    }
+
+    /// Returns the item at the head of the queue, if one exists and is available (i.e., its delay
+    /// has been exceeded), and removes it.
+    pub fn try_pop(&self) -> Option<T> {
+        unsafe {
+            if self
+                .next_available()
+                .map(|until| until <= Instant::now())
+                .unwrap_or(false)
+            {
+                self.0
+                    .get()
+                    .as_mut()
+                    .and_then(|queue| queue.pop())
+                    .map(|delayed| delayed.value)
+            } else {
+                None
+            }
+        }
+    }
+
+    /// Drains all items from the queue and returns them as an iterator.
+    pub fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
+        self.0.get_mut().drain().map(|delayed| delayed.value)
+    }
+}
+
+unsafe impl<T: Send> Sync for DelayQueue<T> {}
+
+impl<T> Default for DelayQueue<T> {
+    fn default() -> Self {
+        DelayQueue(UnsafeCell::new(BinaryHeap::new()))
+    }
+}
+
+#[derive(Debug)]
+struct Delayed<T> {
+    value: T,
+    until: Instant,
+}
+
+impl<T> Delayed<T> {
+    pub fn new(value: T, delay: Duration) -> Self {
+        Delayed {
+            value,
+            until: Instant::now() + delay,
+        }
+    }
+}
+
+/// Implements ordering for `Delayed`, so it can be used to correctly order elements in the
+/// `BinaryHeap` of the `DelayQueue`.
+///
+/// Earlier entries have higher priority (should be popped first), so they are `Greater` than
+/// later entries.
+impl<T> Ord for Delayed<T> {
+    fn cmp(&self, other: &Delayed<T>) -> Ordering {
+        other.until.cmp(&self.until)
+    }
+}
+
+impl<T> PartialOrd for Delayed<T> {
+    fn partial_cmp(&self, other: &Delayed<T>) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl<T> PartialEq for Delayed<T> {
+    fn eq(&self, other: &Delayed<T>) -> bool {
+        self.cmp(other) == Ordering::Equal
+    }
+}
+
+impl<T> Eq for Delayed<T> {}
+
+#[cfg(test)]
+mod tests {
+    use super::DelayQueue;
+    use std::{thread, time::Duration};
+
+    impl<T> DelayQueue<T> {
+        fn len(&self) -> usize {
+            unsafe { self.0.get().as_ref().unwrap().len() }
+        }
+    }
+
+    #[test]
+    fn test_works() {
+        let queue = DelayQueue::default();
+
+        queue.push(1, Duration::from_secs(1)).unwrap();
+        queue.push(2, Duration::from_secs(2)).unwrap();
+        queue.push(3, Duration::from_secs(3)).unwrap();
+
+        assert_eq!(queue.len(), 3);
+        assert_eq!(queue.try_pop(), None);
+
+        thread::sleep(Duration::from_secs(1));
+        assert_eq!(queue.try_pop(), Some(1));
+        assert_eq!(queue.len(), 2);
+
+        thread::sleep(Duration::from_secs(1));
+        assert_eq!(queue.try_pop(), Some(2));
+        assert_eq!(queue.len(), 1);
+
+        thread::sleep(Duration::from_secs(1));
+        assert_eq!(queue.try_pop(), Some(3));
+        assert_eq!(queue.len(), 0);
+
+        assert_eq!(queue.try_pop(), None);
+    }
+
+    #[test]
+    fn test_into_vec() {
+        let mut queue = DelayQueue::default();
+        queue.push(1, Duration::from_secs(1)).unwrap();
+        queue.push(2, Duration::from_secs(2)).unwrap();
+        queue.push(3, Duration::from_secs(3)).unwrap();
+        let mut v: Vec<_> = queue.drain().collect();
+        v.sort();
+        assert_eq!(v, vec![1, 2, 3]);
+    }
+}
diff --git a/src/hive/task/iter.rs b/src/hive/task/iter.rs
new file mode 100644
index 0000000..90bd8b4
--- /dev/null
+++ b/src/hive/task/iter.rs
@@ -0,0 +1,27 @@
+use crate::bee::{Queen, Worker};
+use crate::hive::counter::CounterError;
+use crate::hive::{LocalQueues, Shared, Task};
+use std::sync::Arc;
+
+#[derive(thiserror::Error, Debug)]
+pub enum NextTaskError {
+    #[error("Task receiver disconnected")]
+    Disconnected,
+    #[error("The hive has been poisoned")]
+    Poisoned,
+    #[error("Task counter has invalid state")]
+    InvalidCounter(CounterError),
+}
+
+pub struct TaskIterator<W: Worker, Q: Queen<Kind = W>, L: LocalQueues<W>> {
+    thread_index: usize,
+    shared:
Arc>, +} + +impl, L: LocalQueues> Iterator for TaskIterator { + type Item = Task; + + fn next(&mut self) -> Option { + todo!() + } +} diff --git a/src/hive/task/local.rs b/src/hive/task/local.rs new file mode 100644 index 0000000..f7dee9e --- /dev/null +++ b/src/hive/task/local.rs @@ -0,0 +1,44 @@ +#[cfg(any(feature = "batching", feature = "retry"))] +pub use channel::ChannelLocalQueues as LocalQueuesImpl; +#[cfg(not(any(feature = "batching", feature = "retry")))] +pub use null::NullLocalQueues as LocalQueuesImpl; + +#[cfg(not(any(feature = "batching", feature = "retry")))] +mod null { + use crate::bee::Worker; + use crate::hive::LocalQueues; + use std::marker::PhantomData; + + pub struct NullLocalQueues(PhantomData); + + impl LocalQueues for NullLocalQueues {} +} + +#[cfg(any(feature = "batching", feature = "retry"))] +mod channel { + use crate::bee::Worker; + use crate::hive::{LocalQueues, Task}; + use parking_lot::RwLock; + + pub struct ChannelLocalQueues { + /// thread-local queues of tasks used when the `batching` feature is enabled + #[cfg(feature = "batching")] + batch_queues: RwLock>>>, + /// thread-local queues used for tasks that are waiting to be retried after a failure + #[cfg(feature = "retry")] + retry_queues: RwLock>>>, + } + + impl LocalQueues for ChannelLocalQueues {} + + impl Default for ChannelLocalQueues { + fn default() -> Self { + Self { + #[cfg(feature = "batching")] + batch_queues: Default::default(), + #[cfg(feature = "retry")] + retry_queues: Default::default(), + } + } + } +} diff --git a/src/hive/task/mod.rs b/src/hive/task/mod.rs new file mode 100644 index 0000000..e636987 --- /dev/null +++ b/src/hive/task/mod.rs @@ -0,0 +1,41 @@ +#[cfg(feature = "retry")] +mod delay; +mod iter; +mod local; + +pub use local::LocalQueuesImpl; + +use super::{Outcome, OutcomeSender, Task}; +use crate::bee::{Context, TaskId, Worker}; + +impl Task { + /// Creates a new `Task`. + pub fn new(input: W::Input, ctx: Context, outcome_tx: Option>) -> Self { + Task { + input, + ctx, + outcome_tx, + } + } + + /// Returns the ID of this task. + pub fn id(&self) -> TaskId { + self.ctx.task_id() + } + + /// Consumes this `Task` and returns a tuple `(input, context, outcome_tx)`. + pub fn into_parts(self) -> (W::Input, Context, Option>) { + (self.input, self.ctx, self.outcome_tx) + } + + /// Consumes this `Task` and returns a `Outcome::Unprocessed` outcome with the input and ID, + /// and the outcome sender. 
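+    ///
+    /// A sketch of the intended use when a task must be abandoned (e.g., because the hive
+    /// is poisoned); the resulting pair is handed to the shared state's send-or-store
+    /// logic:
+    ///
+    /// ```ignore
+    /// let (outcome, outcome_tx) = task.into_unprocessed();
+    /// shared.send_or_store_outcome(outcome, outcome_tx);
+    /// ```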
+ pub fn into_unprocessed(self) -> (Outcome, Option>) { + let (input, ctx, outcome_tx) = self.into_parts(); + let outcome = Outcome::Unprocessed { + input, + task_id: ctx.task_id(), + }; + (outcome, outcome_tx) + } +} From 4b8f7c2c0d10fa37978ff9563500928f2d9041b9 Mon Sep 17 00:00:00 2001 From: jdidion Date: Tue, 4 Feb 2025 10:24:40 -0800 Subject: [PATCH 02/67] WIP --- src/hive/hive.rs | 104 ++++++++++++++++++++------------------- src/hive/local.rs | 28 +++++++++++ src/hive/mod.rs | 23 ++++----- src/hive/shared.rs | 30 ++++++----- src/hive/workstealing.rs | 17 +++++++ 5 files changed, 127 insertions(+), 75 deletions(-) create mode 100644 src/hive/local.rs create mode 100644 src/hive/workstealing.rs diff --git a/src/hive/hive.rs b/src/hive/hive.rs index b9cf447..79de77e 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -1,5 +1,7 @@ use super::prelude::*; -use super::{Config, DerefOutcomes, HiveInner, OutcomeSender, Shared, SpawnError, TaskSender}; +use super::{ + Config, DerefOutcomes, HiveInner, LocalQueue, OutcomeSender, Shared, SpawnError, TaskSender, +}; use crate::atomic::Atomic; use crate::bee::{DefaultQueen, Queen, TaskId, Worker}; use crossbeam_utils::Backoff; @@ -15,9 +17,9 @@ pub struct Poisoned; impl> Hive { /// Spawns a new worker thread with the specified index and with access to the `shared` data. - fn try_spawn( + fn try_spawn>( thread_index: usize, - shared: Arc>, + shared: Arc>, ) -> Result, SpawnError> { // spawn a thread that executes the worker loop shared.thread_builder().spawn(move || { @@ -60,11 +62,6 @@ impl> Hive { &self.0.as_ref().unwrap().task_tx } - #[inline] - fn shared(&self) -> &Arc> { - &self.0.as_ref().unwrap().shared - } - /// Attempts to increase the number of worker threads by `num_threads`. Returns the number of /// new worker threads that were successfully started (which may be fewer than `num_threads`), /// or a `Poisoned` error if the hive has been poisoned. @@ -72,7 +69,7 @@ impl> Hive { if num_threads == 0 { return Ok(0); } - let shared = self.shared(); + let shared = &self.0.as_ref().unwrap().shared; // do not start any new threads if the hive is poisoned if shared.is_poisoned() { return Err(Poisoned); @@ -101,7 +98,7 @@ impl> Hive { if self.max_workers() == 0 { dbg!("WARNING: no worker threads are active for hive"); } - let shared = self.shared(); + let shared = &self.0.as_ref().unwrap().shared; let task = shared.prepare_task(input, outcome_tx); let task_id = task.id(); // try to send the task to the hive; if the hive is poisoned or if sending fails, convert @@ -158,7 +155,7 @@ impl> Hive { let task_tx = self.task_tx(); let iter = batch.into_iter(); let (batch_size, _) = iter.size_hint(); - let shared = self.shared(); + let shared = &self.0.as_ref().unwrap().shared; let batch = shared.prepare_batch(batch_size, iter, outcome_tx); if !self.is_poisoned() { batch @@ -176,7 +173,7 @@ impl> Hive { } else { // if the hive is poisoned, convert all tasks into `Unprocessed` outcomes and try to // send them to their outcome channels or store them in the hive - self.shared().abandon_batch(batch) + (&self.0.as_ref().unwrap().shared).abandon_batch(batch) } } @@ -437,7 +434,7 @@ impl> Hive { /// Blocks the calling thread until all tasks finish. pub fn join(&self) { - self.shared().wait_on_done(); + (&self.0.as_ref().unwrap().shared).wait_on_done(); } /// Returns the [`MutexGuard`](parking_lot::MutexGuard) for the [`Queen`]. 
@@ -445,20 +442,23 @@ impl> Hive { /// Note that the `Queen` will remain locked until the returned guard is dropped, and that /// locking the `Queen` prevents new worker threads from being started. pub fn queen(&self) -> impl Deref + '_ { - self.shared().queen.lock() + (&self.0.as_ref().unwrap().shared).queen.lock() } /// Returns the number of worker threads that have been requested, i.e., the maximum number of /// tasks that could be processed concurrently. This may be greater than /// [`active_workers`](Self::active_workers) if any of the worker threads failed to start. pub fn max_workers(&self) -> usize { - self.shared().config.num_threads.get_or_default() + (&self.0.as_ref().unwrap().shared) + .config + .num_threads + .get_or_default() } /// Returns the number of worker threads that have been successfully started. This may be /// fewer than [`max_workers`](Self::max_workers) if any of the worker threads failed to start. pub fn alive_workers(&self) -> usize { - self.shared() + (&self.0.as_ref().unwrap().shared) .spawn_results .lock() .iter() @@ -468,7 +468,7 @@ impl> Hive { /// Returns `true` if there are any "dead" worker threads that failed to spawn. pub fn has_dead_workers(&self) -> bool { - self.shared() + (&self.0.as_ref().unwrap().shared) .spawn_results .lock() .iter() @@ -478,19 +478,19 @@ impl> Hive { /// Attempts to respawn any dead worker threads. Returns the number of worker threads that were /// successfully respawned. pub fn revive_workers(&self) -> usize { - let shared = self.shared(); + let shared = &self.0.as_ref().unwrap().shared; shared .respawn_dead_threads(|thread_index| Self::try_spawn(thread_index, Arc::clone(shared))) } /// Returns the number of tasks currently (queued for processing, being processed). pub fn num_tasks(&self) -> (u64, u64) { - self.shared().num_tasks() + (&self.0.as_ref().unwrap().shared).num_tasks() } /// Returns the number of times one of this `Hive`'s worker threads has panicked. pub fn num_panics(&self) -> usize { - self.shared().num_panics.get() + (&self.0.as_ref().unwrap().shared).num_panics.get() } /// Returns `true` if this `Hive` has been poisoned - i.e., its internal state has been @@ -500,12 +500,12 @@ impl> Hive { /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., /// [`try_into_husk`](Self::try_into_husk)). pub fn is_poisoned(&self) -> bool { - self.shared().is_poisoned() + (&self.0.as_ref().unwrap().shared).is_poisoned() } /// Returns `true` if the suspended flag is set. pub fn is_suspended(&self) -> bool { - self.shared().is_suspended() + (&self.0.as_ref().unwrap().shared).is_suspended() } /// Sets the suspended flag, which notifies worker threads that they a) MAY terminate their @@ -545,18 +545,18 @@ impl> Hive { /// # } /// ``` pub fn suspend(&self) { - self.shared().set_suspended(true); + (&self.0.as_ref().unwrap().shared).set_suspended(true); } /// Unsets the suspended flag, allowing worker threads to continue processing queued tasks. pub fn resume(&self) { - self.shared().set_suspended(false); + (&self.0.as_ref().unwrap().shared).set_suspended(false); } /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over /// the input values. fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator { - self.shared() + (&self.0.as_ref().unwrap().shared) .take_unprocessed() .into_iter() .map(|outcome| match outcome { @@ -569,7 +569,7 @@ impl> Hive { /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that /// were resumed. 
pub fn resume_send(&self, outcome_tx: OutcomeSender) -> Vec { - self.shared() + (&self.0.as_ref().unwrap().shared) .set_suspended(false) .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)) .unwrap_or_default() @@ -579,7 +579,7 @@ impl> Hive { /// processing, with their results to be stored in the queue. Returns a [`Vec`] of task IDs /// that were resumed. pub fn resume_store(&self) -> Vec { - self.shared() + (&self.0.as_ref().unwrap().shared) .set_suspended(false) .then(|| self.swarm_store(self.take_unprocessed_inputs())) .unwrap_or_default() @@ -587,7 +587,7 @@ impl> Hive { /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s. pub fn take_stored(&self) -> HashMap> { - self.shared().take_outcomes() + (&self.0.as_ref().unwrap().shared).take_outcomes() } /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this @@ -599,7 +599,7 @@ impl> Hive { /// /// This method first joins on the `Hive` to wait for all tasks to finish. pub fn try_into_husk(mut self) -> Option> { - if self.shared().num_referrers() > 1 { + if (&self.0.as_ref().unwrap().shared).num_referrers() > 1 { return None; } // take the inner value and replace it with `None` @@ -640,12 +640,12 @@ impl> Clone for Hive { /// i.e., all clones of a `Hive` submit tasks to the same shared worker thread pool. fn clone(&self) -> Self { let inner = self.0.as_ref().unwrap(); - self.shared().referrer_is_cloning(); + (&inner.shared).referrer_is_cloning(); Self(Some(inner.clone())) } } -impl> Clone for HiveInner { +impl, L: LocalQueue> Clone for HiveInner { fn clone(&self) -> Self { HiveInner { task_tx: self.task_tx.clone(), @@ -669,7 +669,9 @@ impl> Debug for Hive { impl> PartialEq for Hive { fn eq(&self, other: &Hive) -> bool { - Arc::ptr_eq(self.shared(), other.shared()) + let self_shared = &self.0.as_ref().unwrap().shared; + let other_shared = &other.0.as_ref().unwrap().shared; + Arc::ptr_eq(self_shared, other_shared) } } @@ -678,12 +680,12 @@ impl> Eq for Hive {} impl> DerefOutcomes for Hive { #[inline] fn outcomes_deref(&self) -> impl Deref>> { - self.shared().outcomes() + (&self.0.as_ref().unwrap().shared).outcomes() } #[inline] fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { - self.shared().outcomes() + (&self.0.as_ref().unwrap().shared).outcomes() } } @@ -705,14 +707,14 @@ impl> Drop for Hive { /// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker /// thread if it panics. 
-struct Sentinel> { +struct Sentinel, L: LocalQueue> { thread_index: usize, - shared: Arc>, + shared: Arc>, active: bool, } -impl> Sentinel { - fn new(thread_index: usize, shared: Arc>) -> Self { +impl, L: LocalQueue> Sentinel { + fn new(thread_index: usize, shared: Arc>) -> Self { Self { thread_index, shared, @@ -726,7 +728,7 @@ impl> Sentinel { } } -impl> Drop for Sentinel { +impl, L: LocalQueue> Drop for Sentinel { fn drop(&mut self) { if self.active { // if the sentinel is active, that means the thread panicked during task execution, so @@ -748,11 +750,11 @@ impl> Drop for Sentinel { #[cfg(not(feature = "affinity"))] mod no_affinity { use crate::bee::{Queen, Worker}; - use crate::hive::{Hive, Shared}; + use crate::hive::{Hive, LocalQueue, Shared}; impl> Hive { #[inline] - pub(super) fn init_thread(_: usize, _: &Shared) {} + pub(super) fn init_thread>(_: usize, _: &Shared) {} } } @@ -783,7 +785,7 @@ mod affinity { num_threads: usize, affinity: C, ) -> Result { - self.shared().add_core_affinity(affinity.into()); + (&self.0.as_ref().unwrap().shared).add_core_affinity(affinity.into()); self.grow(num_threads) } @@ -793,7 +795,7 @@ mod affinity { /// Returns the number of new threads spun up (if any) or a `Poisoned` error if the hive /// has been poisoned. pub fn use_all_cores_with_affinity(&self) -> Result { - self.shared().add_core_affinity(Cores::all()); + (&self.0.as_ref().unwrap().shared).add_core_affinity(Cores::all()); self.use_all_cores() } } @@ -807,13 +809,13 @@ mod batching { impl> Hive { /// Returns the batch size for worker threads. pub fn worker_batch_size(&self) -> usize { - self.shared().batch_size() + (&self.0.as_ref().unwrap().shared).batch_size() } /// Sets the batch size for worker threads. This will block the current thread until all /// worker thread queues can be resized. 
pub fn set_worker_batch_size(&self, batch_size: usize) { - self.shared().set_batch_size(batch_size); + (&self.0.as_ref().unwrap().shared).set_batch_size(batch_size); } } } @@ -821,15 +823,15 @@ mod batching { #[cfg(not(feature = "retry"))] mod no_retry { use crate::bee::{Queen, Worker}; - use crate::hive::{Hive, Outcome, Shared, Task}; + use crate::hive::{Hive, LocalQueue, Outcome, Shared, Task}; impl> Hive { #[inline] - pub(super) fn execute( + pub(super) fn execute>( task: Task, _thread_index: usize, worker: &mut W, - shared: &Shared, + shared: &Shared, ) { let (input, ctx, outcome_tx) = task.into_parts(); let result = worker.apply(input, &ctx); @@ -907,7 +909,7 @@ mod tests { assert_eq!(hive.max_workers(), 4); assert_eq!(hive.alive_workers(), 4); // poison hive using private method - hive.shared().poison(); + hive.0.as_ref().unwrap().shared.poison(); // attempt to spawn a new task assert!(matches!(hive.grow(1), Err(Poisoned))); // make sure the worker count wasn't increased @@ -921,7 +923,7 @@ mod tests { .num_threads(4) .build_with(Caller::of(|i: usize| i * 2)); // poison hive using private method - hive.shared().poison(); + hive.0.as_ref().unwrap().shared.poison(); // submit a task, check that it comes back unprocessed let (tx, rx) = outcome_channel(); let sent_input = 1; @@ -942,7 +944,7 @@ mod tests { .num_threads(4) .build_with(Caller::of(|i: usize| i * 2)); // poison hive using private method - hive.shared().poison(); + hive.0.as_ref().unwrap().shared.poison(); // submit a task, check that it comes back unprocessed let (tx, rx) = outcome_channel(); let inputs = 0..10; diff --git a/src/hive/local.rs b/src/hive/local.rs new file mode 100644 index 0000000..5d9b022 --- /dev/null +++ b/src/hive/local.rs @@ -0,0 +1,28 @@ +#[cfg(any(feature = "batching", feature = "retry"))] +pub use channel::ChannelLocalQueues as LocalQueuesImpl; +#[cfg(not(any(feature = "batching", feature = "retry")))] +pub use null::NullLocalQueues as LocalQueuesImpl; + +#[cfg(not(any(feature = "batching", feature = "retry")))] +mod null { + use crate::hive::LocalQueues; + use crate::bee::Worker; + use std::marker::PhantomData; + + pub struct NullLocalQueues(PhantomData); + + impl LocalQueues for NullLocalQueues {} +} + +#[cfg(any(feature = "batching", feature = "retry"))] +mod channel { + use crate::hive::LocalQueues; + + pub struct ChannelLocalQueues { + /// worker thread-specific queues of tasks used when the `batching` feature is enabled + batch_queues: parking_lot::RwLock>>, + /// queue used for tasks that are waiting to be retried after a failure + #[cfg(feature = "retry")] + retry_queues: parking_lot::RwLock>>>, + } +} diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 19d26df..a43fca8 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -362,12 +362,13 @@ mod gate; #[allow(clippy::module_inception)] mod hive; mod husk; +mod local; mod outcome; // TODO: scoped hive is still a WIP //mod scoped; mod shared; mod task; -//mod workstealing; +mod workstealing; #[cfg(feature = "affinity")] pub mod cores; @@ -423,17 +424,21 @@ type U32 = AtomicOption; #[cfg(feature = "retry")] type U64 = AtomicOption; +trait LocalQueues: Sized + Send + Sync + 'static {} + +type LocalQueuesImpl = local::LocalQueuesImpl; + /// A pool of worker threads that each execute the same function. /// /// See the [module documentation](crate::hive) for details. -pub struct Hive>(Option>); +pub struct Hive>(Option>>); /// A `Hive`'s inner state. 
Wraps a) the `Hive`'s reference to the `Shared` data (which is shared /// with the worker threads) and b) the `Sender>`, which is the sending end of the channel /// used to send tasks to the worker threads. -struct HiveInner> { +struct HiveInner, L: LocalQueues> { task_tx: TaskSender, - shared: Arc>, + shared: Arc>, } type TaskSender = std::sync::mpsc::Sender>; @@ -473,7 +478,7 @@ struct Config { } /// Data shared by all worker threads in a `Hive`. -struct Shared> { +struct Shared, L: LocalQueues> { /// core configuration parameters config: Config, /// the `Queen` used to create new workers @@ -503,12 +508,8 @@ struct Shared> { join_gate: PhasedGate, /// outcomes stored in the hive outcomes: Mutex>>, - /// worker thread-specific queues of tasks used when the `batching` feature is enabled - #[cfg(feature = "batching")] - local_queues: parking_lot::RwLock>>>, - /// queue used for tasks that are waiting to be retried after a failure - #[cfg(feature = "retry")] - retry_queues: parking_lot::RwLock>>>, + /// local queues used by worker threads to manage tasks + local_queues: L, } #[cfg(test)] diff --git a/src/hive/shared.rs b/src/hive/shared.rs index 0eb7fff..789e96a 100644 --- a/src/hive/shared.rs +++ b/src/hive/shared.rs @@ -1,5 +1,5 @@ use super::counter::CounterError; -use super::{Config, Outcome, OutcomeSender, Shared, SpawnError, Task, TaskReceiver}; +use super::{Config, LocalQueue, Outcome, OutcomeSender, Shared, SpawnError, Task, TaskReceiver}; use crate::atomic::{Atomic, AtomicInt, AtomicUsize}; use crate::bee::{Context, Queen, TaskId, Worker}; use crate::channel::SenderExt; @@ -11,7 +11,7 @@ use std::thread::{Builder, JoinHandle}; use std::time::Duration; use std::{fmt, iter, mem}; -impl> Shared { +impl, L: LocalQueue> Shared { /// Creates a new `Shared` instance with the given configuration, queen, and task receiver, /// and all other fields set to their default values. pub fn new(config: Config, queen: Q, task_rx: TaskReceiver) -> Self { @@ -29,7 +29,6 @@ impl> Shared { resume_gate: Default::default(), join_gate: Default::default(), outcomes: Default::default(), - #[cfg(feature = "batching")] local_queues: Default::default(), #[cfg(feature = "retry")] retry_queues: Default::default(), @@ -373,7 +372,7 @@ impl> Shared { } } -impl> fmt::Debug for Shared { +impl, L: LocalQueue> fmt::Debug for Shared { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let (queued, active) = self.num_tasks(); f.debug_struct("Shared") @@ -423,10 +422,10 @@ fn task_recv_timeout(rx: &TaskReceiver) -> Option, #[cfg(not(feature = "batching"))] mod no_batching { - use super::{NextTaskError, Shared, Task}; + use super::{LocalQueue, NextTaskError, Shared, Task}; use crate::bee::{Queen, Worker}; - impl> Shared { + impl, L: LocalQueue> Shared { /// Tries to receive a task from the input channel. /// /// Returns an error if the channel has disconnected. 
Returns `None` if a task is not @@ -442,17 +441,16 @@ mod no_batching { mod batching { use super::{NextTaskError, Shared, Task}; use crate::bee::{Queen, Worker}; + use crate::hive::LocalQueue; use crossbeam_queue::ArrayQueue; use std::collections::HashSet; use std::time::Duration; - impl> Shared { + impl, L: LocalQueue> Shared { pub(super) fn init_local_queues(&self, start_index: usize, end_index: usize) { let mut local_queues = self.local_queues.write(); assert_eq!(local_queues.len(), start_index); - // ArrayQueue cannot be zero-sized - let queue_size = self.batch_size().max(1); - (start_index..end_index).for_each(|_| local_queues.push(ArrayQueue::new(queue_size))) + (start_index..end_index).for_each(|_| local_queues.push(L::new(self))); } /// Returns the local queue batch size. @@ -603,12 +601,12 @@ pub enum NextTaskError { #[cfg(not(feature = "retry"))] mod no_retry { - use super::{NextTaskError, Task}; + use super::{LocalQueue, NextTaskError, Task}; use crate::atomic::Atomic; use crate::bee::{Queen, Worker}; use crate::hive::{Husk, Shared}; - impl> Shared { + impl, L: LocalQueue> Shared { /// Returns the next queued `Task`. The thread blocks until a new task becomes available, and /// since this requires holding a lock on the task `Reciever`, this also blocks any other /// threads that call this method. Returns `None` if the task `Sender` has hung up and there @@ -799,8 +797,14 @@ mod tests { use crate::bee::stock::ThunkWorker; use crate::bee::DefaultQueen; + #[cfg(not(feature = "batching"))] + type LocalQueue = (); + #[cfg(feature = "batching")] + type LocalQueue = crossbeam_deque::ArrayQueue>>; + type VoidThunkWorker = ThunkWorker<()>; - type VoidThunkWorkerShared = super::Shared>; + type VoidThunkWorkerShared = + super::Shared, LocalQueue>; #[test] fn test_sync_shared() { diff --git a/src/hive/workstealing.rs b/src/hive/workstealing.rs new file mode 100644 index 0000000..657ad05 --- /dev/null +++ b/src/hive/workstealing.rs @@ -0,0 +1,17 @@ +use super::{HiveInner, LocalQueue, Shared, Task}; +use crate::bee::{Context, Queen, Worker}; +use crossbeam_deque::Injector as GlobalQueue; +use std::sync::Arc; + +type WorkerQueue = crossbeam_deque::Worker>; + +impl LocalQueue for WorkerQueue { + fn new>(shared: &Arc>) -> Self + where + Self: Sized, + { + Self::new_fifo() + } +} + +pub struct WorkstealingHive>(Option>>); From 811eef861139872b6d6a613ce217473cf958823a Mon Sep 17 00:00:00 2001 From: jdidion Date: Tue, 4 Feb 2025 11:31:29 -0800 Subject: [PATCH 03/67] WIP --- CHANGELOG.md | 2 + Cargo.toml | 2 +- src/hive/delay.rs | 165 -------- src/hive/hive.rs | 967 --------------------------------------------- src/hive/local.rs | 28 -- src/hive/mod.rs | 37 +- src/hive/shared.rs | 40 +- src/hive/task.rs | 34 -- 8 files changed, 34 insertions(+), 1241 deletions(-) delete mode 100644 src/hive/delay.rs delete mode 100644 src/hive/hive.rs delete mode 100644 src/hive/local.rs delete mode 100644 src/hive/task.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index d1dd848..86bc83b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ * Features * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. 
+* Other + * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread-contention ## 0.2.1 diff --git a/Cargo.toml b/Cargo.toml index 18b49ee..95e509e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ name = "perf" harness = false [features] -default = [] +default = ["batching", "retry"] affinity = ["dep:core_affinity"] batching = ["dep:crossbeam-queue"] retry = [] diff --git a/src/hive/delay.rs b/src/hive/delay.rs deleted file mode 100644 index 7e13186..0000000 --- a/src/hive/delay.rs +++ /dev/null @@ -1,165 +0,0 @@ -use std::cell::UnsafeCell; -use std::cmp::Ordering; -use std::collections::BinaryHeap; -use std::time::{Duration, Instant}; - -/// A queue where each item has an associated `Instant` at which it will be available. -/// -/// This is implemented internally as a `UnsafeCell`. -/// -/// SAFETY: This data structure is designed to enable the queue to be modified by a *single thread* -/// using interior mutability. `UnsafeCell` is used for performance - this is safe so long as the -/// queue is only accessed from a single thread at a time. This data structure is *not* thread-safe. -#[derive(Debug)] -pub struct DelayQueue(UnsafeCell>>); - -impl DelayQueue { - /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available, - /// or an error with `item` if there was an error pushing the item. - pub fn push(&self, item: T, delay: Duration) -> Result { - unsafe { - match self.0.get().as_mut() { - Some(queue) => { - let delayed = Delayed::new(item, delay); - let until = delayed.until; - queue.push(delayed); - Ok(until) - } - None => Err(item), - } - } - } - - /// Returns the `Instant` at which the next item will be available. Returns `None` if the queue - /// is empty. - pub fn next_available(&self) -> Option { - unsafe { - self.0 - .get() - .as_ref() - .and_then(|queue| queue.peek().map(|head| head.until)) - } - } - - /// Returns the item at the head of the queue, if one exists and is available (i.e., its delay - /// has been exceeded), and removes it. - pub fn try_pop(&self) -> Option { - unsafe { - if self - .next_available() - .map(|until| until <= Instant::now()) - .unwrap_or(false) - { - self.0 - .get() - .as_mut() - .and_then(|queue| queue.pop()) - .map(|delayed| delayed.value) - } else { - None - } - } - } - - /// Drains all items from the queue and returns them as an iterator. - pub fn drain(&mut self) -> impl Iterator + '_ { - self.0.get_mut().drain().map(|delayed| delayed.value) - } -} - -unsafe impl Sync for DelayQueue {} - -impl Default for DelayQueue { - fn default() -> Self { - DelayQueue(UnsafeCell::new(BinaryHeap::new())) - } -} - -#[derive(Debug)] -struct Delayed { - value: T, - until: Instant, -} - -impl Delayed { - pub fn new(value: T, delay: Duration) -> Self { - Delayed { - value, - until: Instant::now() + delay, - } - } -} - -/// Implements ordering for `Delayed`, so it can be used to correctly order elements in the -/// `BinaryHeap` of the `DelayQueue`. -/// -/// Earlier entries have higher priority (should be popped first), so they are Greater that later -/// entries. 
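The reversed comparison in the `Ord` impl below is what turns the max-heap `BinaryHeap` into an earliest-deadline-first queue. A minimal standalone sketch of the same trick, using only the standard library (`Reverse` plays the role of the flipped `cmp`):

```
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::time::{Duration, Instant};

fn main() {
    // BinaryHeap pops its greatest element first; wrapping the deadline
    // in `Reverse` makes the earliest Instant compare as greatest - the
    // same effect as `other.until.cmp(&self.until)` below
    let now = Instant::now();
    let mut heap = BinaryHeap::new();
    heap.push((Reverse(now + Duration::from_secs(3)), "late"));
    heap.push((Reverse(now + Duration::from_secs(1)), "early"));
    assert_eq!(heap.pop().unwrap().1, "early");
}
```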
-impl Ord for Delayed { - fn cmp(&self, other: &Delayed) -> Ordering { - other.until.cmp(&self.until) - } -} - -impl PartialOrd for Delayed { - fn partial_cmp(&self, other: &Delayed) -> Option { - Some(self.cmp(other)) - } -} - -impl PartialEq for Delayed { - fn eq(&self, other: &Delayed) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl Eq for Delayed {} - -#[cfg(test)] -mod tests { - use super::DelayQueue; - use std::{thread, time::Duration}; - - impl DelayQueue { - fn len(&self) -> usize { - unsafe { self.0.get().as_ref().unwrap().len() } - } - } - - #[test] - fn test_works() { - let queue = DelayQueue::default(); - - queue.push(1, Duration::from_secs(1)).unwrap(); - queue.push(2, Duration::from_secs(2)).unwrap(); - queue.push(3, Duration::from_secs(3)).unwrap(); - - assert_eq!(queue.len(), 3); - assert_eq!(queue.try_pop(), None); - - thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(1)); - assert_eq!(queue.len(), 2); - - thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(2)); - assert_eq!(queue.len(), 1); - - thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(3)); - assert_eq!(queue.len(), 0); - - assert_eq!(queue.try_pop(), None); - } - - #[test] - fn test_into_vec() { - let mut queue = DelayQueue::default(); - queue.push(1, Duration::from_secs(1)).unwrap(); - queue.push(2, Duration::from_secs(2)).unwrap(); - queue.push(3, Duration::from_secs(3)).unwrap(); - let mut v: Vec<_> = queue.drain().collect(); - v.sort(); - assert_eq!(v, vec![1, 2, 3]); - } -} diff --git a/src/hive/hive.rs b/src/hive/hive.rs deleted file mode 100644 index 79de77e..0000000 --- a/src/hive/hive.rs +++ /dev/null @@ -1,967 +0,0 @@ -use super::prelude::*; -use super::{ - Config, DerefOutcomes, HiveInner, LocalQueue, OutcomeSender, Shared, SpawnError, TaskSender, -}; -use crate::atomic::Atomic; -use crate::bee::{DefaultQueen, Queen, TaskId, Worker}; -use crossbeam_utils::Backoff; -use std::collections::HashMap; -use std::fmt::Debug; -use std::ops::{Deref, DerefMut}; -use std::sync::{mpsc, Arc}; -use std::thread::{self, JoinHandle}; - -#[derive(thiserror::Error, Debug)] -#[error("The hive has been poisoned")] -pub struct Poisoned; - -impl> Hive { - /// Spawns a new worker thread with the specified index and with access to the `shared` data. 
- fn try_spawn>( - thread_index: usize, - shared: Arc>, - ) -> Result, SpawnError> { - // spawn a thread that executes the worker loop - shared.thread_builder().spawn(move || { - // perform one-time initialization of the worker thread - Self::init_thread(thread_index, &shared); - // create a Sentinel that will spawn a new thread on panic until it is cancelled - let sentinel = Sentinel::new(thread_index, Arc::clone(&shared)); - // create a new Worker instance - let mut worker = shared.create_worker(); - // execute the main loop - // get the next task to process - this decrements the queued counter and increments - // the active counter - while let Ok(task) = shared.next_task(thread_index) { - // execute the task until it succeeds or we reach maximum retries - this should - // be the only place where a panic can occur - Self::execute(task, thread_index, &mut worker, &shared); - // finish the task - decrements the active counter and notifies other threads - shared.finish_task(false); - } - // this is only reachable when the main loop exits due to the task receiver having - // disconnected; cancel the Sentinel so this thread won't be re-spawned on drop - sentinel.cancel(); - }) - } - - /// Creates a new `Hive`. This should only be called from `Builder`. - /// - /// The `Hive` will attempt to spawn the configured number of worker threads - /// (`config.num_threads`) but the actual number of threads available may be lower if there - /// are any errors during spawning. - pub(super) fn new(config: Config, queen: Q) -> Self { - let (task_tx, task_rx) = mpsc::channel(); - let shared = Arc::new(Shared::new(config.into_sync(), queen, task_rx)); - shared.init_threads(|thread_index| Self::try_spawn(thread_index, Arc::clone(&shared))); - Self(Some(HiveInner { task_tx, shared })) - } - - #[inline] - fn task_tx(&self) -> &TaskSender { - &self.0.as_ref().unwrap().task_tx - } - - /// Attempts to increase the number of worker threads by `num_threads`. Returns the number of - /// new worker threads that were successfully started (which may be fewer than `num_threads`), - /// or a `Poisoned` error if the hive has been poisoned. - pub fn grow(&self, num_threads: usize) -> Result { - if num_threads == 0 { - return Ok(0); - } - let shared = &self.0.as_ref().unwrap().shared; - // do not start any new threads if the hive is poisoned - if shared.is_poisoned() { - return Err(Poisoned); - } - let num_started = shared.grow_threads(num_threads, |thread_index| { - Self::try_spawn(thread_index, Arc::clone(shared)) - }); - Ok(num_started) - } - - /// Sets the number of worker threads to the number of available CPU cores. Returns the number - /// of new threads that were successfully started (which may be `0`), or a `Poisoned` error if - /// the hive has been poisoned. - pub fn use_all_cores(&self) -> Result { - let num_threads = num_cpus::get().saturating_sub(self.max_workers()); - self.grow(num_threads) - } - - /// Sends one input to the `Hive` for processing and returns its ID. The `Outcome` - /// of the task is sent to the `outcome_tx` channel if provided, otherwise it is retained in - /// the `Hive` for later retrieval. - /// - /// This method is called by all the `*apply*` methods. 
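For orientation, a hedged caller-side sketch of the `*apply*` family that funnels into `send_one`, assuming the stock `ThunkWorker` and the `Builder` API used elsewhere in this crate:

```
use beekeeper::bee::stock::{Thunk, ThunkWorker};
use beekeeper::hive::Builder;

fn main() {
    let hive = Builder::new()
        .num_threads(2)
        .build_with_default::<ThunkWorker<usize>>();
    // `apply` blocks until the Outcome for this one task is available;
    // `apply_send`/`apply_store` return the task ID immediately instead
    let _outcome = hive.apply(Thunk::of(|| 6 * 7));
}
```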
- fn send_one(&self, input: W::Input, outcome_tx: Option>) -> TaskId { - #[cfg(debug_assertions)] - if self.max_workers() == 0 { - dbg!("WARNING: no worker threads are active for hive"); - } - let shared = &self.0.as_ref().unwrap().shared; - let task = shared.prepare_task(input, outcome_tx); - let task_id = task.id(); - // try to send the task to the hive; if the hive is poisoned or if sending fails, convert - // the task into an `Unprocessed` outcome and try to send it to the outcome channel; if - // that fails, store the outcome in the hive - if let Some(abandoned_task) = if self.is_poisoned() { - Some(task) - } else { - self.task_tx().send(task).err().map(|err| err.0) - } { - shared.abandon_task(abandoned_task); - } - task_id - } - - /// Sends one `input` to the `Hive` for procesing and returns the result, blocking until the - /// result is available. Creates a channel to send the input and receive the outcome. Returns - /// an [`Outcome`] with the task output or an error. - pub fn apply(&self, input: W::Input) -> Outcome { - let (tx, rx) = outcome_channel(); - let task_id = self.send_one(input, Some(tx)); - rx.recv().unwrap_or_else(|_| Outcome::Missing { task_id }) - } - - /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of - /// the task will be sent to `tx` upon completion. - pub fn apply_send(&self, input: W::Input, tx: OutcomeSender) -> TaskId { - self.send_one(input, Some(tx)) - } - - /// Sends one `input` to the `Hive` for processing and returns its ID immediately. The - /// [`Outcome`] of the task will be retained and available for later retrieval. - pub fn apply_store(&self, input: W::Input) -> TaskId { - self.send_one(input, None) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns a `Vec` of their - /// task IDs. The [`Outcome`]s of the tasks are sent to the `outcome_tx` channel if provided, - /// otherwise they are retained in the `Hive` for later retrieval. - /// - /// The batch is provided as an [`ExactSizeIterator`], which enables the hive to reserve a - /// range of task IDs (a single atomic operation) rather than one at a time. - /// - /// This method is called by all the `swarm*` methods. - fn send_batch(&self, batch: T, outcome_tx: Option>) -> Vec - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - #[cfg(debug_assertions)] - if self.max_workers() == 0 { - dbg!("WARNING: no worker threads are active for hive"); - } - let task_tx = self.task_tx(); - let iter = batch.into_iter(); - let (batch_size, _) = iter.size_hint(); - let shared = &self.0.as_ref().unwrap().shared; - let batch = shared.prepare_batch(batch_size, iter, outcome_tx); - if !self.is_poisoned() { - batch - .map(|task| { - let task_id = task.id(); - // try to send the task to the hive; if sending fails, convert the task into an - // `Unprocessed` outcome and try to send it to the outcome channel; if that - // fails, store the outcome in the hive - if let Err(err) = task_tx.send(task) { - shared.abandon_task(err.0); - } - task_id - }) - .collect() - } else { - // if the hive is poisoned, convert all tasks into `Unprocessed` outcomes and try to - // send them to their outcome channels or store them in the hive - (&self.0.as_ref().unwrap().shared).abandon_batch(batch) - } - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns an iterator over the - /// [`Outcome`]s in the same order as the inputs. 
- /// - /// This method is more efficient than [`map`](Self::map) when the input is an - /// [`ExactSizeIterator`]. - pub fn swarm(&self, batch: T) -> impl Iterator> - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - let (tx, rx) = outcome_channel(); - let task_ids = self.send_batch(batch, Some(tx)); - rx.select_ordered(task_ids) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns an unordered iterator - /// over the [`Outcome`]s. - /// - /// The `Outcome`s will be sent in the order they are completed; use [`swarm`](Self::swarm) to - /// instead receive the `Outcome`s in the order they were submitted. This method is more - /// efficient than [`map_unordered`](Self::map_unordered) when the input is an - /// [`ExactSizeIterator`]. - pub fn swarm_unordered(&self, batch: T) -> impl Iterator> - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - let (tx, rx) = outcome_channel(); - let task_ids = self.send_batch(batch, Some(tx)); - rx.select_unordered(task_ids) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. - /// The [`Outcome`]s of the tasks will be sent to `tx` upon completion. - /// - /// This method is more efficient than [`map_send`](Self::map_send) when the input is an - /// [`ExactSizeIterator`]. - pub fn swarm_send(&self, batch: T, outcome_tx: OutcomeSender) -> Vec - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - self.send_batch(batch, Some(outcome_tx)) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. - /// The [`Outcome`]s of the task are retained and available for later retrieval. - /// - /// This method is more efficient than `map_store` when the input is an [`ExactSizeIterator`]. - pub fn swarm_store(&self, batch: T) -> Vec - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - self.send_batch(batch, None) - } - - /// Iterates over `inputs` and sends each one to the `Hive` for processing and returns an - /// iterator over the [`Outcome`]s in the same order as the inputs. - /// - /// [`swarm`](Self::swarm) should be preferred when `inputs` is an [`ExactSizeIterator`]. - pub fn map( - &self, - inputs: impl IntoIterator, - ) -> impl Iterator> { - let (tx, rx) = outcome_channel(); - let task_ids: Vec<_> = inputs - .into_iter() - .map(|task| self.apply_send(task, tx.clone())) - .collect(); - rx.select_ordered(task_ids) - } - - /// Iterates over `inputs`, sends each one to the `Hive` for processing, and returns an - /// iterator over the [`Outcome`]s in order they become available. - /// - /// [`swarm_unordered`](Self::swarm_unordered) should be preferred when `inputs` is an - /// [`ExactSizeIterator`]. - pub fn map_unordered( - &self, - inputs: impl IntoIterator, - ) -> impl Iterator> { - let (tx, rx) = outcome_channel(); - // `map` is required (rather than `inspect`) because we need owned items - let task_ids: Vec<_> = inputs - .into_iter() - .map(|task| self.apply_send(task, tx.clone())) - .collect(); - rx.select_unordered(task_ids) - } - - /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] - /// of task IDs. The [`Outcome`]s of the tasks will be sent to `tx` upon completion. - /// - /// [`swarm_send`](Self::swarm_send) should be preferred when `inputs` is an - /// [`ExactSizeIterator`]. 
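To make that guidance concrete, a sketch (assuming the stock `Caller` worker and the `OutcomeIteratorExt` helper used in this crate's tests): a range reports its exact length, so `swarm` reserves the whole block of task IDs in one atomic operation and yields results in submission order.

```
use beekeeper::bee::stock::Caller;
use beekeeper::hive::{Builder, OutcomeIteratorExt};

fn main() {
    let hive = Builder::new()
        .num_threads(4)
        .build_with(Caller::of(|i: usize| i * 2));
    // 0..10 is an ExactSizeIterator, so prefer `swarm` over `map` here
    let doubled: Vec<usize> = hive.swarm(0..10).into_outputs().collect();
    assert_eq!(doubled, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
}
```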
- pub fn map_send( - &self, - inputs: impl IntoIterator, - tx: OutcomeSender, - ) -> Vec { - inputs - .into_iter() - .map(|input| self.apply_send(input, tx.clone())) - .collect() - } - - /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] - /// of task IDs. The [`Outcome`]s of the task are retained and available for later retrieval. - /// - /// [`swarm_store`](Self::swarm_store) should be preferred when `inputs` is an - /// [`ExactSizeIterator`]. - pub fn map_store(&self, inputs: impl IntoIterator) -> Vec { - inputs - .into_iter() - .map(|input| self.apply_store(input)) - .collect() - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing. - /// Returns an [`OutcomeBatch`] of the outputs and the final state value. - pub fn scan( - &self, - items: impl IntoIterator, - init: St, - f: F, - ) -> (OutcomeBatch, St) - where - F: FnMut(&mut St, T) -> W::Input, - { - let (tx, rx) = outcome_channel(); - let (task_ids, fold_value) = self.scan_send(items, tx, init, f); - let outcomes = rx.select_unordered(task_ids).into(); - (outcomes, fold_value) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing, - /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final - /// state value. - pub fn try_scan( - &self, - items: impl IntoIterator, - init: St, - mut f: F, - ) -> (OutcomeBatch, Vec, St) - where - F: FnMut(&mut St, T) -> Result, - { - let (tx, rx) = outcome_channel(); - let (task_ids, errors, fold_value) = items.into_iter().fold( - (Vec::new(), Vec::new(), init), - |(mut task_ids, mut errors, mut acc), inp| { - match f(&mut acc, inp) { - Ok(input) => task_ids.push(self.apply_send(input, tx.clone())), - Err(err) => errors.push(err), - } - (task_ids, errors, acc) - }, - ); - let outcomes = rx.select_unordered(task_ids).into(); - (outcomes, errors, fold_value) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. - /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the - /// task IDs and the final state value. - pub fn scan_send( - &self, - items: impl IntoIterator, - tx: OutcomeSender, - init: St, - mut f: F, - ) -> (Vec, St) - where - F: FnMut(&mut St, T) -> W::Input, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { - let input = f(&mut acc, item); - task_ids.push(self.apply_send(input, tx.clone())); - (task_ids, acc) - }) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, - /// or an error. The outputs are sent to `tx` in the order they become available. This - /// function returns the final state value and a [`Vec`] of results, where each result is - /// either a task ID or an error. 
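A short sketch of the simplest member of this family, `scan`, which threads a mutable state value through input generation (assuming the stock `Caller` worker):

```
use beekeeper::bee::stock::Caller;
use beekeeper::hive::Builder;

fn main() {
    let hive = Builder::new()
        .num_threads(4)
        .build_with(Caller::of(|i: usize| i * 10));
    // the closure keeps a running sum of the items while emitting
    // each one as a task input
    let (_outcomes, sum) = hive.scan(1..=4, 0usize, |acc, item| {
        *acc += item;
        item
    });
    assert_eq!(sum, 10);
}
```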
- pub fn try_scan_send( - &self, - items: impl IntoIterator, - tx: OutcomeSender, - init: St, - mut f: F, - ) -> (Vec>, St) - where - F: FnMut(&mut St, T) -> Result, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut results, mut acc), inp| { - results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx.clone()))); - (results, acc) - }) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. - /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of - /// the tasks are retained and available for later retrieval. - pub fn scan_store( - &self, - items: impl IntoIterator, - init: St, - mut f: F, - ) -> (Vec, St) - where - F: FnMut(&mut St, T) -> W::Input, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { - let input = f(&mut acc, item); - task_ids.push(self.apply_store(input)); - (task_ids, acc) - }) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, - /// or an error. This function returns the final value of the state value and a [`Vec`] of - /// results, where each result is either a task ID or an error. The [`Outcome`]s of the - /// tasks are retained and available for later retrieval. - pub fn try_scan_store( - &self, - items: impl IntoIterator, - init: St, - mut f: F, - ) -> (Vec>, St) - where - F: FnMut(&mut St, T) -> Result, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut results, mut acc), item| { - results.push(f(&mut acc, item).map(|input| self.apply_store(input))); - (results, acc) - }) - } - - /// Blocks the calling thread until all tasks finish. - pub fn join(&self) { - (&self.0.as_ref().unwrap().shared).wait_on_done(); - } - - /// Returns the [`MutexGuard`](parking_lot::MutexGuard) for the [`Queen`]. - /// - /// Note that the `Queen` will remain locked until the returned guard is dropped, and that - /// locking the `Queen` prevents new worker threads from being started. - pub fn queen(&self) -> impl Deref + '_ { - (&self.0.as_ref().unwrap().shared).queen.lock() - } - - /// Returns the number of worker threads that have been requested, i.e., the maximum number of - /// tasks that could be processed concurrently. This may be greater than - /// [`active_workers`](Self::active_workers) if any of the worker threads failed to start. - pub fn max_workers(&self) -> usize { - (&self.0.as_ref().unwrap().shared) - .config - .num_threads - .get_or_default() - } - - /// Returns the number of worker threads that have been successfully started. This may be - /// fewer than [`max_workers`](Self::max_workers) if any of the worker threads failed to start. - pub fn alive_workers(&self) -> usize { - (&self.0.as_ref().unwrap().shared) - .spawn_results - .lock() - .iter() - .filter(|result| result.is_ok()) - .count() - } - - /// Returns `true` if there are any "dead" worker threads that failed to spawn. - pub fn has_dead_workers(&self) -> bool { - (&self.0.as_ref().unwrap().shared) - .spawn_results - .lock() - .iter() - .any(|result| result.is_err()) - } - - /// Attempts to respawn any dead worker threads. Returns the number of worker threads that were - /// successfully respawned. 
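A hypothetical helper showing how the health checks above might be paired with `revive_workers` (type parameters and bounds are approximated from this patch's `Hive` definition):

```
use beekeeper::bee::{Queen, Worker};
use beekeeper::hive::Hive;

// hypothetical: respawn any worker slots whose threads failed to start,
// e.g. after a transient OS thread limit was hit
fn revive_if_needed<W: Worker, Q: Queen<Kind = W>>(hive: &Hive<W, Q>) {
    if hive.has_dead_workers() {
        let revived = hive.revive_workers();
        eprintln!("respawned {revived} worker threads");
    }
}
```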
- pub fn revive_workers(&self) -> usize { - let shared = &self.0.as_ref().unwrap().shared; - shared - .respawn_dead_threads(|thread_index| Self::try_spawn(thread_index, Arc::clone(shared))) - } - - /// Returns the number of tasks currently (queued for processing, being processed). - pub fn num_tasks(&self) -> (u64, u64) { - (&self.0.as_ref().unwrap().shared).num_tasks() - } - - /// Returns the number of times one of this `Hive`'s worker threads has panicked. - pub fn num_panics(&self) -> usize { - (&self.0.as_ref().unwrap().shared).num_panics.get() - } - - /// Returns `true` if this `Hive` has been poisoned - i.e., its internal state has been - /// corrupted such that it is no longer able to process tasks. - /// - /// Note that, when a `Hive` is poisoned, it is still possible to call methods that extract - /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., - /// [`try_into_husk`](Self::try_into_husk)). - pub fn is_poisoned(&self) -> bool { - (&self.0.as_ref().unwrap().shared).is_poisoned() - } - - /// Returns `true` if the suspended flag is set. - pub fn is_suspended(&self) -> bool { - (&self.0.as_ref().unwrap().shared).is_suspended() - } - - /// Sets the suspended flag, which notifies worker threads that they a) MAY terminate their - /// current task early (returning an [`Outcome::Unprocessed`]), and b) MUST not accept new - /// tasks, and instead block until the suspended flag is cleared. - /// - /// Call [`resume`](Self::resume) to unset the suspended flag and continue processing tasks. - /// - /// Note: this does *not* prevent new tasks from being queued, and there is a window of time - /// (~1 second) after the suspended flag is set within which a worker thread may still accept a - /// new task. - /// - /// # Examples - /// - /// ``` - /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::Builder; - /// use std::thread; - /// use std::time::Duration; - /// - /// # fn main() { - /// let hive = Builder::new() - /// .num_threads(4) - /// .build_with_default::>(); - /// hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); - /// thread::sleep(Duration::from_secs(1)); // Allow first set of tasks to be started. - /// // There should be 4 active tasks and 6 queued tasks. - /// hive.suspend(); - /// assert_eq!(hive.num_tasks(), (6, 4)); - /// // Wait for active tasks to complete. - /// hive.join(); - /// assert_eq!(hive.num_tasks(), (6, 0)); - /// hive.resume(); - /// // Wait for remaining tasks to complete. - /// hive.join(); - /// assert_eq!(hive.num_tasks(), (0, 0)); - /// # } - /// ``` - pub fn suspend(&self) { - (&self.0.as_ref().unwrap().shared).set_suspended(true); - } - - /// Unsets the suspended flag, allowing worker threads to continue processing queued tasks. - pub fn resume(&self) { - (&self.0.as_ref().unwrap().shared).set_suspended(false); - } - - /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over - /// the input values. - fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator { - (&self.0.as_ref().unwrap().shared) - .take_unprocessed() - .into_iter() - .map(|outcome| match outcome { - Outcome::Unprocessed { input, task_id: _ } => input, - _ => unreachable!(), - }) - } - - /// If this `Hive` is suspended, resumes this `Hive` and re-submits any unprocessed tasks for - /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that - /// were resumed. 
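A sketch of the intended suspend/resume cycle (assuming the stock `ThunkWorker`): any tasks that came back as `Outcome::Unprocessed` while suspended are re-queued, with their fresh `Outcome`s routed to the channel.

```
use beekeeper::bee::stock::{Thunk, ThunkWorker};
use beekeeper::hive::{outcome_channel, Builder};

fn main() {
    let hive = Builder::new()
        .num_threads(2)
        .build_with_default::<ThunkWorker<()>>();
    hive.map_store((0..4).map(|_| Thunk::of(|| ())));
    hive.suspend();
    // ...later: clear the flag and re-submit whatever was left unprocessed
    let (tx, rx) = outcome_channel();
    let resumed = hive.resume_send(tx);
    println!("re-submitted {} task(s)", resumed.len());
    drop(rx);
}
```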
- pub fn resume_send(&self, outcome_tx: OutcomeSender) -> Vec { - (&self.0.as_ref().unwrap().shared) - .set_suspended(false) - .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)) - .unwrap_or_default() - } - - /// If this `Hive` is suspended, resumes this `Hive` and re-submit any unprocessed tasks for - /// processing, with their results to be stored in the queue. Returns a [`Vec`] of task IDs - /// that were resumed. - pub fn resume_store(&self) -> Vec { - (&self.0.as_ref().unwrap().shared) - .set_suspended(false) - .then(|| self.swarm_store(self.take_unprocessed_inputs())) - .unwrap_or_default() - } - - /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s. - pub fn take_stored(&self) -> HashMap> { - (&self.0.as_ref().unwrap().shared).take_outcomes() - } - - /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this - /// `Hive`, including any stored task outcomes, and all the data necessary to create a new - /// `Hive`. - /// - /// If this `Hive` has been cloned, and those clones have not been dropped, this method - /// returns `None` since it cannot take exclusive ownership of the internal shared data. - /// - /// This method first joins on the `Hive` to wait for all tasks to finish. - pub fn try_into_husk(mut self) -> Option> { - if (&self.0.as_ref().unwrap().shared).num_referrers() > 1 { - return None; - } - // take the inner value and replace it with `None` - let inner = self.0.take().unwrap(); - // wait for all tasks to finish - inner.shared.wait_on_done(); - // drop the task sender so receivers will drop automatically - drop(inner.task_tx); - // wait for worker threads to drop, then take ownership of the shared data and convert it - // into a Husk - let mut shared = inner.shared; - let mut backoff = None::; - loop { - // TODO: may want to have some timeout or other kind of limit to prevent this from - // looping forever if a worker thread somehow gets stuck, or if the `num_referrers` - // counter is corrupted - shared = match Arc::try_unwrap(shared) { - Ok(shared) => { - return Some(shared.try_into_husk()); - } - Err(shared) => { - backoff.get_or_insert_with(Backoff::new).spin(); - shared - } - }; - } - } -} - -impl Default for Hive> { - fn default() -> Self { - Builder::default().build_with_default::() - } -} - -impl> Clone for Hive { - /// Creates a shallow copy of this `Hive` containing references to its same internal state, - /// i.e., all clones of a `Hive` submit tasks to the same shared worker thread pool. 
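Illustrating the doc comment above: both handles below feed the same pool, so a clone can be moved to another thread while the original keeps submitting (a sketch assuming the stock `ThunkWorker`):

```
use beekeeper::bee::stock::{Thunk, ThunkWorker};
use beekeeper::hive::Builder;
use std::thread;

fn main() {
    let hive = Builder::new()
        .num_threads(4)
        .build_with_default::<ThunkWorker<usize>>();
    let clone = hive.clone();
    // submissions from either handle land in the same task channel
    let handle = thread::spawn(move || clone.apply(Thunk::of(|| 1)));
    let _a = hive.apply(Thunk::of(|| 2));
    let _b = handle.join().unwrap();
}
```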
- fn clone(&self) -> Self { - let inner = self.0.as_ref().unwrap(); - (&inner.shared).referrer_is_cloning(); - Self(Some(inner.clone())) - } -} - -impl, L: LocalQueue> Clone for HiveInner { - fn clone(&self) -> Self { - HiveInner { - task_tx: self.task_tx.clone(), - shared: Arc::clone(&self.shared), - } - } -} - -impl> Debug for Hive { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(inner) = self.0.as_ref() { - f.debug_struct("Hive") - .field("task_tx", &inner.task_tx) - .field("shared", &inner.shared) - .finish() - } else { - f.write_str("Hive {}") - } - } -} - -impl> PartialEq for Hive { - fn eq(&self, other: &Hive) -> bool { - let self_shared = &self.0.as_ref().unwrap().shared; - let other_shared = &other.0.as_ref().unwrap().shared; - Arc::ptr_eq(self_shared, other_shared) - } -} - -impl> Eq for Hive {} - -impl> DerefOutcomes for Hive { - #[inline] - fn outcomes_deref(&self) -> impl Deref>> { - (&self.0.as_ref().unwrap().shared).outcomes() - } - - #[inline] - fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { - (&self.0.as_ref().unwrap().shared).outcomes() - } -} - -impl> Drop for Hive { - fn drop(&mut self) { - // if this Hive has already been turned into a Husk, it's inner value will be `None` - if let Some(inner) = self.0.as_ref() { - // reduce the referrer count - let _ = inner.shared.referrer_is_dropping(); - // if this Hive is the only one with a pointer to the shared data, poison it - // to prevent any worker threads that still have access to the shared data from - // re-spawning. - if inner.shared.num_referrers() == 0 { - inner.shared.poison(); - } - } - } -} - -/// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker -/// thread if it panics. -struct Sentinel, L: LocalQueue> { - thread_index: usize, - shared: Arc>, - active: bool, -} - -impl, L: LocalQueue> Sentinel { - fn new(thread_index: usize, shared: Arc>) -> Self { - Self { - thread_index, - shared, - active: true, - } - } - - /// Cancel and destroy this sentinel. - fn cancel(mut self) { - self.active = false; - } -} - -impl, L: LocalQueue> Drop for Sentinel { - fn drop(&mut self) { - if self.active { - // if the sentinel is active, that means the thread panicked during task execution, so - // we have to finish the task here before respawning - self.shared.finish_task(thread::panicking()); - // only respawn if the sentinel is active and the hive has not been poisoned - if !self.shared.is_poisoned() { - // can't do anything with the previous result - let _ = self - .shared - .respawn_thread(self.thread_index, |thread_index| { - Hive::try_spawn(thread_index, Arc::clone(&self.shared)) - }); - } - } - } -} - -#[cfg(not(feature = "affinity"))] -mod no_affinity { - use crate::bee::{Queen, Worker}; - use crate::hive::{Hive, LocalQueue, Shared}; - - impl> Hive { - #[inline] - pub(super) fn init_thread>(_: usize, _: &Shared) {} - } -} - -#[cfg(feature = "affinity")] -mod affinity { - use crate::bee::{Queen, Worker}; - use crate::hive::cores::Cores; - use crate::hive::{Hive, Poisoned, Shared}; - - impl> Hive { - /// Tries to pin the worker thread to a specific CPU core. - #[inline] - pub(super) fn init_thread(thread_index: usize, shared: &Shared) { - if let Some(core) = shared.get_core_affinity(thread_index) { - core.try_pin_current(); - } - } - - /// Attempts to increase the number of worker threads by `num_threads`. 
- /// - /// The provided `affinity` specifies additional CPU core indices to which the worker - /// threads may be pinned - these are added to the existing pool of core indices (if any). - /// - /// Returns the number of new worker threads that were successfully started (which may be - /// fewer than `num_threads`) or a `Poisoned` error if the hive has been poisoned. - pub fn grow_with_affinity>( - &self, - num_threads: usize, - affinity: C, - ) -> Result { - (&self.0.as_ref().unwrap().shared).add_core_affinity(affinity.into()); - self.grow(num_threads) - } - - /// Sets the number of worker threads to the number of available CPU cores. An attempt is - /// made to pin each worker thread to a different CPU core. - /// - /// Returns the number of new threads spun up (if any) or a `Poisoned` error if the hive - /// has been poisoned. - pub fn use_all_cores_with_affinity(&self) -> Result { - (&self.0.as_ref().unwrap().shared).add_core_affinity(Cores::all()); - self.use_all_cores() - } - } -} - -#[cfg(feature = "batching")] -mod batching { - use crate::bee::{Queen, Worker}; - use crate::hive::Hive; - - impl> Hive { - /// Returns the batch size for worker threads. - pub fn worker_batch_size(&self) -> usize { - (&self.0.as_ref().unwrap().shared).batch_size() - } - - /// Sets the batch size for worker threads. This will block the current thread until all - /// worker thread queues can be resized. - pub fn set_worker_batch_size(&self, batch_size: usize) { - (&self.0.as_ref().unwrap().shared).set_batch_size(batch_size); - } - } -} - -#[cfg(not(feature = "retry"))] -mod no_retry { - use crate::bee::{Queen, Worker}; - use crate::hive::{Hive, LocalQueue, Outcome, Shared, Task}; - - impl> Hive { - #[inline] - pub(super) fn execute>( - task: Task, - _thread_index: usize, - worker: &mut W, - shared: &Shared, - ) { - let (input, ctx, outcome_tx) = task.into_parts(); - let result = worker.apply(input, &ctx); - let outcome = Outcome::from_worker_result(result, ctx.task_id()); - shared.send_or_store_outcome(outcome, outcome_tx); - } - } -} - -#[cfg(feature = "retry")] -mod retry { - use crate::bee::{ApplyError, Queen, Worker}; - use crate::hive::{Hive, Outcome, Shared, Task}; - - impl> Hive { - #[inline] - pub(super) fn execute( - task: Task, - thread_index: usize, - worker: &mut W, - shared: &Shared, - ) { - let (input, mut ctx, outcome_tx) = task.into_parts(); - match worker.apply(input, &ctx) { - Err(ApplyError::Retryable { input, .. }) if shared.can_retry(&ctx) => { - ctx.inc_attempt(); - shared.queue_retry(thread_index, input, ctx, outcome_tx); - } - result => { - let outcome = Outcome::from_worker_result(result, ctx.task_id()); - shared.send_or_store_outcome(outcome, outcome_tx); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::Poisoned; - use crate::bee::stock::{Caller, Thunk, ThunkWorker}; - use crate::hive::{outcome_channel, Builder, Outcome, OutcomeIteratorExt}; - use std::collections::HashMap; - use std::thread; - use std::time::Duration; - - #[test] - fn test_suspend() { - let hive = Builder::new() - .num_threads(4) - .build_with_default::>(); - let outcome_iter = - hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); - // Allow first set of tasks to be started. - thread::sleep(Duration::from_secs(1)); - // There should be 4 active tasks and 6 queued tasks. - hive.suspend(); - assert_eq!(hive.num_tasks(), (6, 4)); - // Wait for active tasks to complete. 
- hive.join(); - assert_eq!(hive.num_tasks(), (6, 0)); - hive.resume(); - // Wait for remaining tasks to complete. - hive.join(); - assert_eq!(hive.num_tasks(), (0, 0)); - let outputs: Vec<_> = outcome_iter.into_outputs().collect(); - assert_eq!(outputs.len(), 10); - } - - #[test] - fn test_spawn_after_poison() { - let hive = Builder::new() - .num_threads(4) - .build_with_default::>(); - assert_eq!(hive.max_workers(), 4); - assert_eq!(hive.alive_workers(), 4); - // poison hive using private method - hive.0.as_ref().unwrap().shared.poison(); - // attempt to spawn a new task - assert!(matches!(hive.grow(1), Err(Poisoned))); - // make sure the worker count wasn't increased - assert_eq!(hive.max_workers(), 4); - assert_eq!(hive.alive_workers(), 4); - } - - #[test] - fn test_apply_after_poison() { - let hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i: usize| i * 2)); - // poison hive using private method - hive.0.as_ref().unwrap().shared.poison(); - // submit a task, check that it comes back unprocessed - let (tx, rx) = outcome_channel(); - let sent_input = 1; - let sent_task_id = hive.apply_send(sent_input, tx.clone()); - let outcome = rx.recv().unwrap(); - match outcome { - Outcome::Unprocessed { input, task_id } => { - assert_eq!(input, sent_input); - assert_eq!(task_id, sent_task_id); - } - _ => panic!("Expected unprocessed outcome"), - } - } - - #[test] - fn test_swarm_after_poison() { - let hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i: usize| i * 2)); - // poison hive using private method - hive.0.as_ref().unwrap().shared.poison(); - // submit a task, check that it comes back unprocessed - let (tx, rx) = outcome_channel(); - let inputs = 0..10; - let task_ids: HashMap = hive - .swarm_send(inputs.clone(), tx) - .into_iter() - .zip(inputs) - .collect(); - for outcome in rx.into_iter().take(10) { - match outcome { - Outcome::Unprocessed { input, task_id } => { - let expected_input = task_ids.get(&task_id); - assert!(expected_input.is_some()); - assert_eq!(input, *expected_input.unwrap()); - } - _ => panic!("Expected unprocessed outcome"), - } - } - } -} diff --git a/src/hive/local.rs b/src/hive/local.rs deleted file mode 100644 index 5d9b022..0000000 --- a/src/hive/local.rs +++ /dev/null @@ -1,28 +0,0 @@ -#[cfg(any(feature = "batching", feature = "retry"))] -pub use channel::ChannelLocalQueues as LocalQueuesImpl; -#[cfg(not(any(feature = "batching", feature = "retry")))] -pub use null::NullLocalQueues as LocalQueuesImpl; - -#[cfg(not(any(feature = "batching", feature = "retry")))] -mod null { - use crate::hive::LocalQueues; - use crate::bee::Worker; - use std::marker::PhantomData; - - pub struct NullLocalQueues(PhantomData); - - impl LocalQueues for NullLocalQueues {} -} - -#[cfg(any(feature = "batching", feature = "retry"))] -mod channel { - use crate::hive::LocalQueues; - - pub struct ChannelLocalQueues { - /// worker thread-specific queues of tasks used when the `batching` feature is enabled - batch_queues: parking_lot::RwLock>>, - /// queue used for tasks that are waiting to be retried after a failure - #[cfg(feature = "retry")] - retry_queues: parking_lot::RwLock>>>, - } -} diff --git a/src/hive/mod.rs b/src/hive/mod.rs index a43fca8..8332c64 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -356,26 +356,22 @@ //! ([`Husk::as_builder`](crate::hive::husk::Husk::as_builder)) or a new `Hive` //! ([`Husk::into_hive`](crate::hive::husk::Husk::into_hive)). 
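A brief sketch of the `Husk` round trip mentioned above (hedged: `try_into_husk` joins on outstanding tasks and requires that no other clones of the hive are alive):

```
use beekeeper::bee::stock::ThunkWorker;
use beekeeper::hive::Builder;

fn main() {
    let hive = Builder::new()
        .num_threads(2)
        .build_with_default::<ThunkWorker<()>>();
    // consume the hive into its remnants, then rebuild a fresh hive
    // from the same queen and configuration
    if let Some(husk) = hive.try_into_husk() {
        let _revived = husk.into_hive();
    }
}
```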
mod builder; +mod channel; mod config; mod counter; mod gate; -#[allow(clippy::module_inception)] -mod hive; mod husk; -mod local; mod outcome; -// TODO: scoped hive is still a WIP //mod scoped; mod shared; mod task; -mod workstealing; +//mod workstealing; #[cfg(feature = "affinity")] pub mod cores; -#[cfg(feature = "retry")] -mod delay; pub use self::builder::Builder; +pub use self::channel::Poisoned; #[cfg(feature = "batching")] pub use self::config::set_batch_size_default; pub use self::config::{reset_defaults, set_num_threads_default, set_num_threads_default_all}; @@ -383,7 +379,6 @@ pub use self::config::{reset_defaults, set_num_threads_default, set_num_threads_ pub use self::config::{ set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, }; -pub use self::hive::Poisoned; pub use self::husk::Husk; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; @@ -409,6 +404,7 @@ pub mod prelude { use self::counter::DualCounter; use self::gate::{Gate, PhasedGate}; use self::outcome::{DerefOutcomes, OwnedOutcomes}; +use self::task::LocalQueuesImpl; use crate::atomic::{AtomicAny, AtomicBool, AtomicOption, AtomicUsize}; use crate::bee::{Context, Queen, TaskId, Worker}; use parking_lot::Mutex; @@ -424,24 +420,22 @@ type U32 = AtomicOption; #[cfg(feature = "retry")] type U64 = AtomicOption; -trait LocalQueues: Sized + Send + Sync + 'static {} - -type LocalQueuesImpl = local::LocalQueuesImpl; - /// A pool of worker threads that each execute the same function. /// /// See the [module documentation](crate::hive) for details. -pub struct Hive>(Option>>); +pub struct Hive>(Option>); /// A `Hive`'s inner state. Wraps a) the `Hive`'s reference to the `Shared` data (which is shared /// with the worker threads) and b) the `Sender>`, which is the sending end of the channel /// used to send tasks to the worker threads. -struct HiveInner, L: LocalQueues> { +struct HiveInner> { task_tx: TaskSender, - shared: Arc>, + shared: Arc>>, } +/// Type alias for the input task channel sender type TaskSender = std::sync::mpsc::Sender>; +/// Type alias for the input task channel receiver type TaskReceiver = std::sync::mpsc::Receiver>; /// Internal representation of a task to be processed by a `Hive`. @@ -465,8 +459,7 @@ struct Config { /// CPU cores to which worker threads can be pinned #[cfg(feature = "affinity")] affinity: Any, - /// Maximum number of tasks for a worker thread to - /// take when receiving tasks from the input channel + /// Maximum number of tasks for a worker thread to take when receiving from the input channel #[cfg(feature = "batching")] batch_size: Usize, /// Maximum number of retries for a task @@ -507,11 +500,17 @@ struct Shared, L: LocalQueues> { /// gate used by client threads to wait until all tasks have completed join_gate: PhasedGate, /// outcomes stored in the hive + /// TODO: switch to using thread-local outcome maps that need to be gathered when extracting outcomes: Mutex>>, /// local queues used by worker threads to manage tasks local_queues: L, } +/// Trait that provides access to thread-specific queues for managing tasks. Ideally, these queues +/// would be managed in a global thread-local data structure, but it has a generic type that +/// requires it to be stored within the Hive's shared data. 
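A reduced sketch of the dispatch pattern being introduced here: an empty marker trait, one concrete implementation per feature combination, and a single type alias that the rest of the code names (identifiers mirror this patch; the alias line is illustrative):

```
#![allow(dead_code)]

// marker trait for thread-local task queues, with the bounds this patch uses
trait LocalQueues: Sized + Default + Send + Sync + 'static {}

/// no-op implementation selected when no queueing features are enabled
#[derive(Default)]
struct NullLocalQueues;

impl LocalQueues for NullLocalQueues {}

#[cfg(not(any(feature = "batching", feature = "retry")))]
type LocalQueuesImpl = NullLocalQueues;

fn main() {}
```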
+trait LocalQueues: Sized + Default + Send + Sync + 'static {} + #[cfg(test)] mod tests { use super::{Builder, Hive, Outcome, OutcomeIteratorExt, OutcomeStore}; @@ -1762,7 +1761,7 @@ mod affinity_tests { .core_affinity(0..2) .build_with_default::>(); - hive.map_store((0..10).map(move |i| { + channel.map_store((0..10).map(move |i| { Thunk::of(move || { if let Some(affininty) = core_affinity::get_core_ids() { eprintln!("task {} on thread with affinity {:?}", i, affininty); @@ -1779,7 +1778,7 @@ mod affinity_tests { .with_default_core_affinity() .build_with_default::>(); - hive.map_store((0..num_cpus::get()).map(move |i| { + channel.map_store((0..num_cpus::get()).map(move |i| { Thunk::of(move || { if let Some(affininty) = core_affinity::get_core_ids() { eprintln!("task {} on thread with affinity {:?}", i, affininty); diff --git a/src/hive/shared.rs b/src/hive/shared.rs index 789e96a..cc02f9a 100644 --- a/src/hive/shared.rs +++ b/src/hive/shared.rs @@ -1,5 +1,4 @@ -use super::counter::CounterError; -use super::{Config, LocalQueue, Outcome, OutcomeSender, Shared, SpawnError, Task, TaskReceiver}; +use super::{Config, LocalQueues, Outcome, OutcomeSender, Shared, SpawnError, Task, TaskReceiver}; use crate::atomic::{Atomic, AtomicInt, AtomicUsize}; use crate::bee::{Context, Queen, TaskId, Worker}; use crate::channel::SenderExt; @@ -11,7 +10,7 @@ use std::thread::{Builder, JoinHandle}; use std::time::Duration; use std::{fmt, iter, mem}; -impl, L: LocalQueue> Shared { +impl, L: LocalQueues> Shared { /// Creates a new `Shared` instance with the given configuration, queen, and task receiver, /// and all other fields set to their default values. pub fn new(config: Config, queen: Q, task_rx: TaskReceiver) -> Self { @@ -30,8 +29,6 @@ impl, L: LocalQueue> Shared { join_gate: Default::default(), outcomes: Default::default(), local_queues: Default::default(), - #[cfg(feature = "retry")] - retry_queues: Default::default(), } } @@ -372,7 +369,7 @@ impl, L: LocalQueue> Shared { } } -impl, L: LocalQueue> fmt::Debug for Shared { +impl, L: LocalQueues> fmt::Debug for Shared { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let (queued, active) = self.num_tasks(); f.debug_struct("Shared") @@ -441,12 +438,12 @@ mod no_batching { mod batching { use super::{NextTaskError, Shared, Task}; use crate::bee::{Queen, Worker}; - use crate::hive::LocalQueue; + use crate::hive::LocalQueues; use crossbeam_queue::ArrayQueue; use std::collections::HashSet; use std::time::Duration; - impl, L: LocalQueue> Shared { + impl, L: LocalQueues> Shared { pub(super) fn init_local_queues(&self, start_index: usize, end_index: usize) { let mut local_queues = self.local_queues.write(); assert_eq!(local_queues.len(), start_index); @@ -589,16 +586,6 @@ fn send_or_store>>( }); } -#[derive(thiserror::Error, Debug)] -pub enum NextTaskError { - #[error("Task receiver disconnected")] - Disconnected, - #[error("The hive has been poisoned")] - Poisoned, - #[error("Task counter has invalid state")] - InvalidCounter(CounterError), -} - #[cfg(not(feature = "retry"))] mod no_retry { use super::{LocalQueue, NextTaskError, Task}; @@ -666,10 +653,10 @@ mod retry { use crate::atomic::Atomic; use crate::bee::{Context, Queen, Worker}; use crate::hive::delay::DelayQueue; - use crate::hive::{Husk, OutcomeSender, Shared, Task}; + use crate::hive::{Husk, LocalQueues, OutcomeSender, Shared, Task}; use std::time::{Duration, Instant}; - impl> Shared { + impl, L: LocalQueues> Shared { /// Initializes the retry queues worker threads in the specified range. 
pub(super) fn init_retry_queues(&self, start_index: usize, end_index: usize) { let mut retry_queues = self.retry_queues.write(); @@ -796,15 +783,14 @@ mod retry { mod tests { use crate::bee::stock::ThunkWorker; use crate::bee::DefaultQueen; - - #[cfg(not(feature = "batching"))] - type LocalQueue = (); - #[cfg(feature = "batching")] - type LocalQueue = crossbeam_deque::ArrayQueue>>; + use crate::hive::LocalQueuesImpl; type VoidThunkWorker = ThunkWorker<()>; - type VoidThunkWorkerShared = - super::Shared, LocalQueue>; + type VoidThunkWorkerShared = super::Shared< + VoidThunkWorker, + DefaultQueen, + LocalQueuesImpl, + >; #[test] fn test_sync_shared() { diff --git a/src/hive/task.rs b/src/hive/task.rs deleted file mode 100644 index 03d071b..0000000 --- a/src/hive/task.rs +++ /dev/null @@ -1,34 +0,0 @@ -use super::{Outcome, OutcomeSender, Task}; -use crate::bee::{Context, TaskId, Worker}; - -impl Task { - /// Creates a new `Task`. - pub fn new(input: W::Input, ctx: Context, outcome_tx: Option>) -> Self { - Task { - input, - ctx, - outcome_tx, - } - } - - /// Returns the ID of this task. - pub fn id(&self) -> TaskId { - self.ctx.task_id() - } - - /// Consumes this `Task` and returns a tuple `(input, context, outcome_tx)`. - pub fn into_parts(self) -> (W::Input, Context, Option>) { - (self.input, self.ctx, self.outcome_tx) - } - - /// Consumes this `Task` and returns a `Outcome::Unprocessed` outcome with the input and ID, - /// and the outcome sender. - pub fn into_unprocessed(self) -> (Outcome, Option>) { - let (input, ctx, outcome_tx) = self.into_parts(); - let outcome = Outcome::Unprocessed { - input, - task_id: ctx.task_id(), - }; - (outcome, outcome_tx) - } -} From 38d4138f95b83c68a3fb33baf07296cd5ae3fd62 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 7 Feb 2025 14:07:02 -0800 Subject: [PATCH 04/67] refactor --- CHANGELOG.md | 3 + Cargo.toml | 6 +- benches/perf.rs | 2 +- src/atomic.rs | 4 - src/bee/context.rs | 132 +++++++-- src/bee/mod.rs | 2 + src/bee/stock/call.rs | 26 +- src/bee/stock/echo.rs | 2 +- src/bee/stock/thunk.rs | 6 +- src/bee/worker.rs | 14 +- src/hive/channel.rs | 370 +++++++++++++----------- src/hive/counter.rs | 2 - src/hive/husk.rs | 4 +- src/hive/mod.rs | 164 ++++++++--- src/hive/outcome/outcome.rs | 161 +++++++++-- src/hive/outcome/store.rs | 2 +- src/hive/shared.rs | 556 +++++++++++++----------------------- src/hive/task/delay.rs | 11 +- src/hive/task/global.rs | 77 +++++ src/hive/task/iter.rs | 27 -- src/hive/task/local.rs | 365 ++++++++++++++++++++++- src/hive/task/mod.rs | 72 +++-- src/util.rs | 2 +- 23 files changed, 1299 insertions(+), 711 deletions(-) create mode 100644 src/hive/task/global.rs delete mode 100644 src/hive/task/iter.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 86bc83b..66bd282 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,11 @@ ## 0.3.0 +* **Breaking** + * `beekeeper::bee::Context` has been changed from a struct to a trait, and `beekeeper::bee::Worker` now has a generic parameter for the `Context` type. * Features * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. + * Added the `Context::submit` method, which enables tasks to submit new tasks to the `Hive`. 
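A hedged sketch of the new `Context::submit` behavior as implemented later in this patch: with no backing task context (as in `Context::empty`), the input cannot be enqueued and is handed straight back:

```
use beekeeper::bee::Context;

fn main() {
    // an empty Context has no worker-thread state behind it, so
    // `submit` returns the input as an Err instead of queueing it
    let mut ctx = Context::<usize>::empty();
    assert_eq!(ctx.submit(42), Err(42));
}
```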
* Other * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread-contention diff --git a/Cargo.toml b/Cargo.toml index 95e509e..25c7f68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ repository = "https://github.com/jdidion/beekeeper" license = "MIT OR Apache-2.0" [dependencies] +crossbeam-channel = "0.5.13" crossbeam-deque = "0.8.6" crossbeam-utils = "0.8.20" num = "0.4.3" @@ -20,8 +21,7 @@ thiserror = "1.0.63" core_affinity = { version = "0.8.1", optional = true } # required with the `batching` feature crossbeam-queue = { version = "0.3.12", optional = true } -# alternate channel implementations that can be enabled with features -crossbeam-channel = { version = "0.5.13", optional = true } +# required with alternate outcome channel implementations that can be enabled with features flume = { version = "0.11.1", optional = true } loole = { version = "0.4.0", optional = true } @@ -41,7 +41,7 @@ default = ["batching", "retry"] affinity = ["dep:core_affinity"] batching = ["dep:crossbeam-queue"] retry = [] -crossbeam = ["dep:crossbeam-channel"] +crossbeam = [] flume = ["dep:flume"] loole = ["dep:loole"] diff --git a/benches/perf.rs b/benches/perf.rs index 53175fb..35715ac 100644 --- a/benches/perf.rs +++ b/benches/perf.rs @@ -17,7 +17,7 @@ fn bench_apply_short_task(bencher: Bencher, (num_threads, num_tasks): (&usize, & bencher.bench_local(|| { let (tx, rx) = outcome_channel(); for i in 0..*num_tasks { - hive.apply_send(i, tx.clone()); + hive.apply_send(i, &tx); } hive.join(); rx.into_iter().take(*num_tasks).for_each(black_box_drop); diff --git a/src/atomic.rs b/src/atomic.rs index ec3d74c..2f08c18 100644 --- a/src/atomic.rs +++ b/src/atomic.rs @@ -28,8 +28,6 @@ pub trait Atomic: Clone + Debug + Default + From pub struct Orderings { pub load: Ordering, pub swap: Ordering, - pub fetch_update_set: Ordering, - pub fetch_update_fetch: Ordering, pub fetch_add: Ordering, pub fetch_sub: Ordering, } @@ -39,8 +37,6 @@ impl Default for Orderings { Orderings { load: Ordering::Acquire, swap: Ordering::Release, - fetch_update_set: Ordering::AcqRel, - fetch_update_fetch: Ordering::Acquire, fetch_add: Ordering::AcqRel, fetch_sub: Ordering::AcqRel, } diff --git a/src/bee/context.rs b/src/bee/context.rs index e2b9315..0eec520 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -1,57 +1,127 @@ //! The context for a task processed by a `Worker`. -use crate::atomic::{Atomic, AtomicBool}; use std::fmt::Debug; -use std::sync::Arc; pub type TaskId = usize; -/// Context for a task. -#[derive(Debug, Default)] -pub struct Context { +/// Trait that provides a `Context` with limited access to a worker thread's state during +/// task execution. +pub trait TaskContext: Debug { + /// Returns `true` if tasks in progress should be cancelled. + fn cancel_tasks(&self) -> bool; + + /// Submits a new task to the `Hive` that is executing the current task. + fn submit_task(&self, input: I) -> TaskId; +} + +#[derive(Debug)] +pub struct Context<'a, I> { task_id: TaskId, - cancelled: Arc, + task_ctx: Option>>, + subtask_ids: Option>, #[cfg(feature = "retry")] attempt: u32, } -impl Context { - /// Creates a new `Context` with the given task_id and shared cancellation status. - pub fn new(task_id: TaskId, cancelled: Arc) -> Self { - Self { - task_id, - cancelled, - #[cfg(feature = "retry")] - attempt: 0, +impl<'a, I> Context<'a, I> { + /// The task_id of this task within the `Hive`. 
+ pub fn task_id(&self) -> TaskId { + self.task_id + } + + /// Returns `true` if the task has been cancelled. + /// + /// A long-running `Worker` should check this periodically and, if it returns `true`, exit + /// early with an `ApplyError::Cancelled` result. + pub fn is_cancelled(&self) -> bool { + self.task_ctx + .as_ref() + .map(|worker| worker.cancel_tasks()) + .unwrap_or(false) + } + + /// Submits a new task to the `Hive` that is executing the current task. + /// + /// If a thread-local queue is available and has capacity, the task will be added to it, + /// otherwise it is added to the global queue. The ID of the submitted task is stored in this + /// `Context` and ultimately returned in the `Outcome` of the submitting task. + /// + /// The task will be submitted with the same outcome sender as the current task, or stored in + /// the `Hive` if there is no sender. + /// + /// Returns an `Err` containing `input` if the new task was not successfully submitted. + pub fn submit(&mut self, input: I) -> Result<(), I> { + if let Some(worker) = self.task_ctx.as_ref() { + let task_id = worker.submit_task(input); + self.subtask_ids.get_or_insert_default().push(task_id); + Ok(()) + } else { + Err(input) } } - /// Creates an empty `Context`. + pub(crate) fn into_subtask_ids(self) -> Option> { + self.subtask_ids + } +} + +#[cfg(not(feature = "retry"))] +impl<'a, I> Context<'a, I> { + /// Returns a new empty context. This is primarily useful for testing. pub fn empty() -> Self { - Self::new(0, Arc::new(AtomicBool::from(false))) + Self { + task_id: 0, + task_ctx: None, + subtask_ids: None, + } } - /// The task_id of this task within the `Hive`. - pub fn task_id(&self) -> TaskId { - self.task_id + /// Creates a new `Context` with the given task_id and shared cancellation status. + pub fn new(task_id: TaskId, task_ctx: Option>>) -> Self { + Self { + task_id, + task_ctx, + subtask_ids: None, + } } - /// Returns `true` if the task has been cancelled. A long-running `Worker` should check this - /// periodically and, if it returns `true`, exit early with an `ApplyError::Cancelled` result. - pub fn is_cancelled(&self) -> bool { - self.cancelled.get() + /// The number of previous attempts to execute the current task. + /// + /// Always returns `0`. + pub fn attempt(&self) -> u32 { + 0 } } #[cfg(feature = "retry")] -impl Context { - /// The current retry attempt. The value is `0` for the first attempt and increments by `1` for - /// each retry attempt (if any). - pub fn attempt(&self) -> u32 { - self.attempt +impl<'a, I> Context<'a, I> { + /// Returns a new empty context. This is primarily useful for testing. + pub fn empty() -> Self { + Self { + task_id: 0, + attempt: 0, + task_ctx: None, + subtask_ids: None, + } + } + + /// Creates a new `Context` with the given task_id and shared cancellation status. + pub fn new( + task_id: TaskId, + attempt: u32, + task_ctx: Option>>, + ) -> Self { + Self { + task_id, + attempt, + task_ctx, + subtask_ids: None, + } } - /// Increments the retry attempt. - pub(crate) fn inc_attempt(&mut self) { - self.attempt += 1; + /// The number of previous attempts to execute the current task. + /// + /// Returns `0` for the first attempt and increments by `1` for each retry attempt (if any). 
+ pub fn attempt(&self) -> u32 { + self.attempt } } diff --git a/src/bee/mod.rs b/src/bee/mod.rs index b72fdc9..56662d7 100644 --- a/src/bee/mod.rs +++ b/src/bee/mod.rs @@ -120,6 +120,8 @@ pub use error::{ApplyError, ApplyRefError}; pub use queen::{CloneQueen, DefaultQueen, Queen}; pub use worker::{RefWorker, RefWorkerResult, Worker, WorkerError, WorkerResult}; +pub(crate) use context::TaskContext; + pub mod prelude { pub use super::{ ApplyError, ApplyRefError, Context, Queen, RefWorker, RefWorkerResult, Worker, WorkerError, diff --git a/src/bee/stock/call.rs b/src/bee/stock/call.rs index 93bf057..a78dc18 100644 --- a/src/bee/stock/call.rs +++ b/src/bee/stock/call.rs @@ -57,7 +57,7 @@ where type Error = (); #[inline] - fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { Ok((self.0.f)(input)) } } @@ -110,7 +110,7 @@ where type Error = E; #[inline] - fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { (self.0.f)(input).map_err(|error| ApplyError::Fatal { error, input: None }) } } @@ -167,7 +167,11 @@ where type Error = E; #[inline] - fn apply_ref(&mut self, input: &Self::Input, _: &Context) -> RefWorkerResult { + fn apply_ref( + &mut self, + input: &Self::Input, + _: &Context, + ) -> RefWorkerResult { (self.0.f)(input).map_err(|error| ApplyRefError::Fatal(error)) } } @@ -203,7 +207,7 @@ impl RetryCaller { I: Send + Sync + 'static, O: Send + Sync + 'static, E: Send + Sync + Debug + 'static, - F: FnMut(I, &Context) -> Result> + Clone + 'static, + F: FnMut(I, &Context) -> Result> + Clone + 'static, { RetryCaller(Callable::of(f)) } @@ -214,14 +218,14 @@ where I: Send + 'static, O: Send + 'static, E: Send + Debug + 'static, - F: FnMut(I, &Context) -> Result> + Clone + 'static, + F: FnMut(I, &Context) -> Result> + Clone + 'static, { type Input = I; type Output = O; type Error = E; #[inline] - fn apply(&mut self, input: Self::Input, ctx: &Context) -> WorkerResult { + fn apply(&mut self, input: Self::Input, ctx: &Context) -> WorkerResult { (self.0.f)(input, ctx) } } @@ -232,7 +236,7 @@ impl Clone for RetryCaller { } } -impl Result>> Debug +impl) -> Result>> Debug for RetryCaller { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -242,7 +246,7 @@ impl Result>> Debug impl From for RetryCaller where - F: FnMut(I, &Context) -> Result> + Clone + 'static, + F: FnMut(I, &Context) -> Result> + Clone + 'static, { fn from(f: F) -> Self { RetryCaller(Callable::of(f)) @@ -265,9 +269,11 @@ mod tests { (bool, u8), u8, String, - impl FnMut((bool, u8), &Context) -> Result> + Clone + 'static, + impl FnMut((bool, u8), &Context<(bool, u8)>) -> Result> + + Clone + + 'static, > { - RetryCaller::of(|input: (bool, u8), _: &Context| { + RetryCaller::of(|input: (bool, u8), _: &Context<(bool, u8)>| { if input.0 { Ok(input.1 + 1) } else { diff --git a/src/bee/stock/echo.rs b/src/bee/stock/echo.rs index 4aa39c2..d1dccaa 100644 --- a/src/bee/stock/echo.rs +++ b/src/bee/stock/echo.rs @@ -18,7 +18,7 @@ impl Worker for EchoWorker { type Error = (); #[inline] - fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { Ok(input) } } diff --git a/src/bee/stock/thunk.rs b/src/bee/stock/thunk.rs index 32ed9d8..26b4b15 100644 --- a/src/bee/stock/thunk.rs +++ b/src/bee/stock/thunk.rs @@ -20,7 +20,7 @@ impl Worker for ThunkWorker { type Error = (); 
#[inline] - fn apply(&mut self, f: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, f: Self::Input, _: &Context) -> WorkerResult { Ok(f.0.call_box()) } } @@ -41,7 +41,7 @@ impl Worker for FunkWorker type Error = E; #[inline] - fn apply(&mut self, f: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, f: Self::Input, _: &Context) -> WorkerResult { f.0.call_box() .map_err(|error| ApplyError::Fatal { error, input: None }) } @@ -63,7 +63,7 @@ impl Worker for PunkWorker { type Output = T; type Error = (); - fn apply(&mut self, f: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, f: Self::Input, _: &Context) -> WorkerResult { Panic::try_call_boxed(None, f.0).map_err(|payload| ApplyError::Panic { input: None, payload, diff --git a/src/bee/worker.rs b/src/bee/worker.rs index 177cefa..adbc93b 100644 --- a/src/bee/worker.rs +++ b/src/bee/worker.rs @@ -32,7 +32,7 @@ pub trait Worker: Debug + Sized + 'static { /// /// This method should not panic. If it may panic, then [`Panic::try_call`] should be used to /// catch the panic and turn it into an [`ApplyError::Panic`] error. - fn apply(&mut self, _: Self::Input, _: &Context) -> WorkerResult; + fn apply(&mut self, _: Self::Input, _: &Context) -> WorkerResult; /// Applies this `Worker`'s function sequentially to an iterator of inputs and returns a /// iterator over the outputs. @@ -65,7 +65,7 @@ pub trait RefWorker: Debug + Sized + 'static { /// The type of error produced by this function. type Error: Send + Debug; - fn apply_ref(&mut self, _: &Self::Input, _: &Context) -> RefWorkerResult; + fn apply_ref(&mut self, _: &Self::Input, _: &Context) -> RefWorkerResult; } /// Blanket implementation of `Worker` for `RefWorker` that calls `apply_ref` and catches any @@ -81,7 +81,7 @@ where type Output = O; type Error = E; - fn apply(&mut self, input: Self::Input, ctx: &Context) -> WorkerResult { + fn apply(&mut self, input: Self::Input, ctx: &Context) -> WorkerResult { match Panic::try_call(None, || self.apply_ref(&input, ctx)) { Ok(Ok(output)) => Ok(output), Ok(Err(error)) => Err(error.into_apply_error(input)), @@ -106,7 +106,7 @@ mod tests { type Output = u8; type Error = (); - fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { Ok(input + 1) } } @@ -133,7 +133,11 @@ mod tests { type Output = u8; type Error = (); - fn apply_ref(&mut self, input: &Self::Input, _: &Context) -> RefWorkerResult { + fn apply_ref( + &mut self, + input: &Self::Input, + _: &Context, + ) -> RefWorkerResult { match *input { 0 => Err(ApplyRefError::Retryable(())), 1 => Err(ApplyRefError::Fatal(())), diff --git a/src/hive/channel.rs b/src/hive/channel.rs index 05e6685..5bfe6a3 100644 --- a/src/hive/channel.rs +++ b/src/hive/channel.rs @@ -1,14 +1,13 @@ use super::prelude::*; -use super::{ - Config, DerefOutcomes, HiveInner, LocalQueues, OutcomeSender, Shared, SpawnError, TaskSender, -}; +use super::task::{ChannelGlobalQueue, ChannelLocalQueues}; +use super::{Config, DerefOutcomes, GlobalQueue, LocalQueues, OutcomeSender, Shared, SpawnError}; use crate::atomic::Atomic; -use crate::bee::{DefaultQueen, Queen, TaskId, Worker}; +use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; use crossbeam_utils::Backoff; use std::collections::HashMap; -use std::fmt::Debug; +use std::fmt; use std::ops::{Deref, DerefMut}; -use std::sync::{mpsc, Arc}; +use std::sync::Arc; use std::thread::{self, JoinHandle}; #[derive(thiserror::Error, Debug)] @@ -17,24 
+16,24 @@ pub struct Poisoned; impl> Hive { /// Spawns a new worker thread with the specified index and with access to the `shared` data. - fn try_spawn>( + fn try_spawn, L: LocalQueues>( thread_index: usize, - shared: Arc>, + shared: &Arc>, ) -> Result, SpawnError> { + let thread_builder = shared.thread_builder(); + let shared = Arc::clone(shared); // spawn a thread that executes the worker loop - shared.thread_builder().spawn(move || { + thread_builder.spawn(move || { // perform one-time initialization of the worker thread Self::init_thread(thread_index, &shared); // create a Sentinel that will spawn a new thread on panic until it is cancelled let sentinel = Sentinel::new(thread_index, Arc::clone(&shared)); - // create a new Worker instance + // create a new worker to process tasks let mut worker = shared.create_worker(); - // execute the main loop - // get the next task to process - this decrements the queued counter and increments - // the active counter - while let Ok(task) = shared.next_task(thread_index) { - // execute the task until it succeeds or we reach maximum retries - this should - // be the only place where a panic can occur + // execute the main loop: get the next task to process, which decrements the queued + // counter and increments the active counter + while let Some(task) = shared.get_next_task(thread_index) { + // execute the task and dispose of the outcome Self::execute(task, thread_index, &mut worker, &shared); // finish the task - decrements the active counter and notifies other threads shared.finish_task(false); @@ -51,15 +50,15 @@ impl> Hive { /// (`config.num_threads`) but the actual number of threads available may be lower if there /// are any errors during spawning. pub(super) fn new(config: Config, queen: Q) -> Self { - let (task_tx, task_rx) = mpsc::channel(); - let shared = Arc::new(Shared::new(config.into_sync(), queen, task_rx)); - shared.init_threads(|thread_index| Self::try_spawn(thread_index, Arc::clone(&shared))); - Self(Some(HiveInner { task_tx, shared })) + let global_queue = ChannelGlobalQueue::default(); + let shared = Arc::new(Shared::new(config.into_sync(), global_queue, queen)); + shared.init_threads(|thread_index| Self::try_spawn(thread_index, &shared)); + Self(Some(shared)) } #[inline] - fn task_tx(&self) -> &TaskSender { - &self.0.as_ref().unwrap().task_tx + fn shared(&self) -> &Arc, ChannelLocalQueues>> { + &self.0.as_ref().unwrap() } /// Attempts to increase the number of worker threads by `num_threads`. Returns the number of @@ -69,13 +68,13 @@ impl> Hive { if num_threads == 0 { return Ok(0); } - let shared = &self.0.as_ref().unwrap().shared; + let shared = self.shared(); // do not start any new threads if the hive is poisoned if shared.is_poisoned() { return Err(Poisoned); } let num_started = shared.grow_threads(num_threads, |thread_index| { - Self::try_spawn(thread_index, Arc::clone(shared)) + Self::try_spawn(thread_index, shared) }); Ok(num_started) } @@ -93,25 +92,13 @@ impl> Hive { /// the `Hive` for later retrieval. /// /// This method is called by all the `*apply*` methods. 
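Runtime growth then looks like the following sketch (the `Builder::new` constructor and import paths are assumed; `grow` and `use_all_cores` are as defined above):

```rust
use beekeeper::bee::stock::Caller;
use beekeeper::hive::Builder;

let hive = Builder::new()
    .num_threads(2)
    .build_with(Caller::of(|i: usize| i * 2));
// `grow` reports how many threads actually started (may be < requested)
let started = hive.grow(2).expect("hive was poisoned");
assert!(started <= 2);
// scale up to one worker thread per core; returns Ok(0) if already there
hive.use_all_cores().expect("hive was poisoned");
```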
-    fn send_one(&self, input: W::Input, outcome_tx: Option>) -> TaskId {
+    #[inline]
+    fn send_one(&self, input: W::Input, outcome_tx: Option<&OutcomeSender>) -> TaskId {
         #[cfg(debug_assertions)]
         if self.max_workers() == 0 {
             dbg!("WARNING: no worker threads are active for hive");
         }
-        let shared = &self.0.as_ref().unwrap().shared;
-        let task = shared.prepare_task(input, outcome_tx);
-        let task_id = task.id();
-        // try to send the task to the hive; if the hive is poisoned or if sending fails, convert
-        // the task into an `Unprocessed` outcome and try to send it to the outcome channel; if
-        // that fails, store the outcome in the hive
-        if let Some(abandoned_task) = if self.is_poisoned() {
-            Some(task)
-        } else {
-            self.task_tx().send(task).err().map(|err| err.0)
-        } {
-            shared.abandon_task(abandoned_task);
-        }
-        task_id
+        self.shared().send_one_global(input, outcome_tx)
     }

     /// Sends one `input` to the `Hive` for processing and returns the result, blocking until the
     /// result is available. Creates a channel to send the input and receive the outcome. Returns
     /// an [`Outcome`] with the task output or an error.
     pub fn apply(&self, input: W::Input) -> Outcome {
         let (tx, rx) = outcome_channel();
-        let task_id = self.send_one(input, Some(tx));
+        let task_id = self.send_one(input, Some(&tx));
         rx.recv().unwrap_or_else(|_| Outcome::Missing { task_id })
     }

     /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of
     /// the task will be sent to `tx` upon completion.
-    pub fn apply_send(&self, input: W::Input, tx: OutcomeSender) -> TaskId {
+    pub fn apply_send(&self, input: W::Input, tx: &OutcomeSender) -> TaskId {
         self.send_one(input, Some(tx))
     }
@@ -143,7 +130,8 @@ impl> Hive {
     /// range of task IDs (a single atomic operation) rather than one at a time.
     ///
     /// This method is called by all the `swarm*` methods.
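The three single-task entry points built on `send_one` differ only in where the `Outcome` ends up; a sketch, given a `hive` over a worker with `usize` inputs:

```rust
let (tx, rx) = outcome_channel();
let outcome = hive.apply(21);           // blocks until the Outcome arrives
let task_id = hive.apply_send(42, &tx); // Outcome delivered to `tx`
let stored_id = hive.apply_store(63);   // Outcome retained inside the hive
let sent = rx.recv().unwrap();          // result of the apply_send task
```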
- fn send_batch(&self, batch: T, outcome_tx: Option>) -> Vec + #[inline] + fn send_batch(&self, batch: T, outcome_tx: Option<&OutcomeSender>) -> Vec where T: IntoIterator, T::IntoIter: ExactSizeIterator, @@ -152,29 +140,7 @@ impl> Hive { if self.max_workers() == 0 { dbg!("WARNING: no worker threads are active for hive"); } - let task_tx = self.task_tx(); - let iter = batch.into_iter(); - let (batch_size, _) = iter.size_hint(); - let shared = &self.0.as_ref().unwrap().shared; - let batch = shared.prepare_batch(batch_size, iter, outcome_tx); - if !self.is_poisoned() { - batch - .map(|task| { - let task_id = task.id(); - // try to send the task to the hive; if sending fails, convert the task into an - // `Unprocessed` outcome and try to send it to the outcome channel; if that - // fails, store the outcome in the hive - if let Err(err) = task_tx.send(task) { - shared.abandon_task(err.0); - } - task_id - }) - .collect() - } else { - // if the hive is poisoned, convert all tasks into `Unprocessed` outcomes and try to - // send them to their outcome channels or store them in the hive - (&self.0.as_ref().unwrap().shared).abandon_batch(batch) - } + self.shared().send_batch_global(batch, outcome_tx) } /// Sends a `batch` of inputs to the `Hive` for processing, and returns an iterator over the @@ -188,7 +154,7 @@ impl> Hive { T::IntoIter: ExactSizeIterator, { let (tx, rx) = outcome_channel(); - let task_ids = self.send_batch(batch, Some(tx)); + let task_ids = self.send_batch(batch, Some(&tx)); rx.select_ordered(task_ids) } @@ -205,7 +171,7 @@ impl> Hive { T::IntoIter: ExactSizeIterator, { let (tx, rx) = outcome_channel(); - let task_ids = self.send_batch(batch, Some(tx)); + let task_ids = self.send_batch(batch, Some(&tx)); rx.select_unordered(task_ids) } @@ -214,7 +180,7 @@ impl> Hive { /// /// This method is more efficient than [`map_send`](Self::map_send) when the input is an /// [`ExactSizeIterator`]. 
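For batches, the ordered and unordered variants trade latency for ordering; a sketch, assuming a hive whose worker maps `usize` to `usize` (so `Outcome::unwrap` yields a `usize`):

```rust
// ordered: outcomes are yielded in task-submission order
let total: usize = hive.swarm(0..10).map(Outcome::unwrap).sum();
// unordered: outcomes are yielded as tasks complete (typically lower latency)
let outcomes: Vec<_> = hive.swarm_unordered(0..10).collect();
```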
- pub fn swarm_send(&self, batch: T, outcome_tx: OutcomeSender) -> Vec + pub fn swarm_send(&self, batch: T, outcome_tx: &OutcomeSender) -> Vec where T: IntoIterator, T::IntoIter: ExactSizeIterator, @@ -245,7 +211,7 @@ impl> Hive { let (tx, rx) = outcome_channel(); let task_ids: Vec<_> = inputs .into_iter() - .map(|task| self.apply_send(task, tx.clone())) + .map(|task| self.apply_send(task, &tx)) .collect(); rx.select_ordered(task_ids) } @@ -263,7 +229,7 @@ impl> Hive { // `map` is required (rather than `inspect`) because we need owned items let task_ids: Vec<_> = inputs .into_iter() - .map(|task| self.apply_send(task, tx.clone())) + .map(|task| self.apply_send(task, &tx)) .collect(); rx.select_unordered(task_ids) } @@ -276,11 +242,11 @@ impl> Hive { pub fn map_send( &self, inputs: impl IntoIterator, - tx: OutcomeSender, + tx: &OutcomeSender, ) -> Vec { inputs .into_iter() - .map(|input| self.apply_send(input, tx.clone())) + .map(|input| self.apply_send(input, tx)) .collect() } @@ -309,7 +275,7 @@ impl> Hive { F: FnMut(&mut St, T) -> W::Input, { let (tx, rx) = outcome_channel(); - let (task_ids, fold_value) = self.scan_send(items, tx, init, f); + let (task_ids, fold_value) = self.scan_send(items, &tx, init, f); let outcomes = rx.select_unordered(task_ids).into(); (outcomes, fold_value) } @@ -332,7 +298,7 @@ impl> Hive { (Vec::new(), Vec::new(), init), |(mut task_ids, mut errors, mut acc), inp| { match f(&mut acc, inp) { - Ok(input) => task_ids.push(self.apply_send(input, tx.clone())), + Ok(input) => task_ids.push(self.apply_send(input, &tx)), Err(err) => errors.push(err), } (task_ids, errors, acc) @@ -349,7 +315,7 @@ impl> Hive { pub fn scan_send( &self, items: impl IntoIterator, - tx: OutcomeSender, + tx: &OutcomeSender, init: St, mut f: F, ) -> (Vec, St) @@ -360,7 +326,7 @@ impl> Hive { .into_iter() .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { let input = f(&mut acc, item); - task_ids.push(self.apply_send(input, tx.clone())); + task_ids.push(self.apply_send(input, tx)); (task_ids, acc) }) } @@ -373,7 +339,7 @@ impl> Hive { pub fn try_scan_send( &self, items: impl IntoIterator, - tx: OutcomeSender, + tx: &OutcomeSender, init: St, mut f: F, ) -> (Vec>, St) @@ -383,7 +349,7 @@ impl> Hive { items .into_iter() .fold((Vec::new(), init), |(mut results, mut acc), inp| { - results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx.clone()))); + results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx))); (results, acc) }) } @@ -434,7 +400,7 @@ impl> Hive { /// Blocks the calling thread until all tasks finish. pub fn join(&self) { - (&self.0.as_ref().unwrap().shared).wait_on_done(); + (self.shared()).wait_on_done(); } /// Returns the [`MutexGuard`](parking_lot::MutexGuard) for the [`Queen`]. @@ -442,23 +408,20 @@ impl> Hive { /// Note that the `Queen` will remain locked until the returned guard is dropped, and that /// locking the `Queen` prevents new worker threads from being started. pub fn queen(&self) -> impl Deref + '_ { - (&self.0.as_ref().unwrap().shared).queen.lock() + (self.shared()).queen.lock() } /// Returns the number of worker threads that have been requested, i.e., the maximum number of /// tasks that could be processed concurrently. This may be greater than /// [`active_workers`](Self::active_workers) if any of the worker threads failed to start. 
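Because the sender is now borrowed rather than cloned at every API layer, fan-out helpers can reuse one channel cheaply; compare the test usage later in this patch:

```rust
let (tx, rx) = outcome_channel();
let task_ids = hive.swarm_send(0..10, &tx); // one sender, many tasks
let outputs: Vec<_> = rx.select_unordered_outputs(task_ids).collect();
```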
pub fn max_workers(&self) -> usize { - (&self.0.as_ref().unwrap().shared) - .config - .num_threads - .get_or_default() + (self.shared()).config.num_threads.get_or_default() } /// Returns the number of worker threads that have been successfully started. This may be /// fewer than [`max_workers`](Self::max_workers) if any of the worker threads failed to start. pub fn alive_workers(&self) -> usize { - (&self.0.as_ref().unwrap().shared) + (self.shared()) .spawn_results .lock() .iter() @@ -468,7 +431,7 @@ impl> Hive { /// Returns `true` if there are any "dead" worker threads that failed to spawn. pub fn has_dead_workers(&self) -> bool { - (&self.0.as_ref().unwrap().shared) + (self.shared()) .spawn_results .lock() .iter() @@ -478,19 +441,18 @@ impl> Hive { /// Attempts to respawn any dead worker threads. Returns the number of worker threads that were /// successfully respawned. pub fn revive_workers(&self) -> usize { - let shared = &self.0.as_ref().unwrap().shared; - shared - .respawn_dead_threads(|thread_index| Self::try_spawn(thread_index, Arc::clone(shared))) + let shared = self.shared(); + shared.respawn_dead_threads(|thread_index| Self::try_spawn(thread_index, shared)) } /// Returns the number of tasks currently (queued for processing, being processed). pub fn num_tasks(&self) -> (u64, u64) { - (&self.0.as_ref().unwrap().shared).num_tasks() + (self.shared()).num_tasks() } /// Returns the number of times one of this `Hive`'s worker threads has panicked. pub fn num_panics(&self) -> usize { - (&self.0.as_ref().unwrap().shared).num_panics.get() + (self.shared()).num_panics.get() } /// Returns `true` if this `Hive` has been poisoned - i.e., its internal state has been @@ -500,12 +462,12 @@ impl> Hive { /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., /// [`try_into_husk`](Self::try_into_husk)). pub fn is_poisoned(&self) -> bool { - (&self.0.as_ref().unwrap().shared).is_poisoned() + (self.shared()).is_poisoned() } /// Returns `true` if the suspended flag is set. pub fn is_suspended(&self) -> bool { - (&self.0.as_ref().unwrap().shared).is_suspended() + (self.shared()).is_suspended() } /// Sets the suspended flag, which notifies worker threads that they a) MAY terminate their @@ -545,18 +507,18 @@ impl> Hive { /// # } /// ``` pub fn suspend(&self) { - (&self.0.as_ref().unwrap().shared).set_suspended(true); + (self.shared()).set_suspended(true); } /// Unsets the suspended flag, allowing worker threads to continue processing queued tasks. pub fn resume(&self) { - (&self.0.as_ref().unwrap().shared).set_suspended(false); + (self.shared()).set_suspended(false); } /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over /// the input values. fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator { - (&self.0.as_ref().unwrap().shared) + (self.shared()) .take_unprocessed() .into_iter() .map(|outcome| match outcome { @@ -568,8 +530,8 @@ impl> Hive { /// If this `Hive` is suspended, resumes this `Hive` and re-submits any unprocessed tasks for /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that /// were resumed. 
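A suspend/resume round trip, as a sketch (the `hive` and channel setup follow the conventions used in this file's tests):

```rust
hive.suspend(); // in-flight tasks may finish or cancel; queued tasks wait
let (tx, rx) = outcome_channel();
// resume and re-queue anything left unprocessed, with results sent to `tx`
let resumed: Vec<TaskId> = hive.resume_send(&tx);
// alternatively, hive.resume_store() retains the results in the hive
```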
- pub fn resume_send(&self, outcome_tx: OutcomeSender) -> Vec { - (&self.0.as_ref().unwrap().shared) + pub fn resume_send(&self, outcome_tx: &OutcomeSender) -> Vec { + (self.shared()) .set_suspended(false) .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)) .unwrap_or_default() @@ -579,7 +541,7 @@ impl> Hive { /// processing, with their results to be stored in the queue. Returns a [`Vec`] of task IDs /// that were resumed. pub fn resume_store(&self) -> Vec { - (&self.0.as_ref().unwrap().shared) + (self.shared()) .set_suspended(false) .then(|| self.swarm_store(self.take_unprocessed_inputs())) .unwrap_or_default() @@ -587,7 +549,7 @@ impl> Hive { /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s. pub fn take_stored(&self) -> HashMap> { - (&self.0.as_ref().unwrap().shared).take_outcomes() + (self.shared()).take_outcomes() } /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this @@ -599,18 +561,17 @@ impl> Hive { /// /// This method first joins on the `Hive` to wait for all tasks to finish. pub fn try_into_husk(mut self) -> Option> { - if (&self.0.as_ref().unwrap().shared).num_referrers() > 1 { + if (self.shared()).num_referrers() > 1 { return None; } // take the inner value and replace it with `None` - let inner = self.0.take().unwrap(); + let mut shared = self.0.take().unwrap(); + // close the global queue to prevent new tasks from being submitted + shared.global_queue.close(); // wait for all tasks to finish - inner.shared.wait_on_done(); - // drop the task sender so receivers will drop automatically - drop(inner.task_tx); + shared.wait_on_done(); // wait for worker threads to drop, then take ownership of the shared data and convert it // into a Husk - let mut shared = inner.shared; let mut backoff = None::; loop { // TODO: may want to have some timeout or other kind of limit to prevent this from @@ -618,7 +579,7 @@ impl> Hive { // counter is corrupted shared = match Arc::try_unwrap(shared) { Ok(shared) => { - return Some(shared.try_into_husk()); + return Some(shared.into_husk()); } Err(shared) => { backoff.get_or_insert_with(Backoff::new).spin(); @@ -639,28 +600,16 @@ impl> Clone for Hive { /// Creates a shallow copy of this `Hive` containing references to its same internal state, /// i.e., all clones of a `Hive` submit tasks to the same shared worker thread pool. 
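In other words (a sketch; `PartialEq` is defined below as pointer identity on the shared state):

```rust
let hive2 = hive.clone();
assert_eq!(hive, hive2); // same shared state, same worker pool
hive2.apply_store(7);    // submitting through either handle is equivalent
```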
     fn clone(&self) -> Self {
-        let inner = self.0.as_ref().unwrap();
-        (&inner.shared).referrer_is_cloning();
-        Self(Some(inner.clone()))
-    }
-}
-
-impl> Clone for HiveInner {
-    fn clone(&self) -> Self {
-        HiveInner {
-            task_tx: self.task_tx.clone(),
-            shared: Arc::clone(&self.shared),
-        }
+        let shared = self.0.as_ref().unwrap();
+        shared.referrer_is_cloning();
+        Self(Some(shared.clone()))
     }
 }

-impl> Debug for Hive {
+impl> fmt::Debug for Hive {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        if let Some(inner) = self.0.as_ref() {
-            f.debug_struct("Hive")
-                .field("task_tx", &inner.task_tx)
-                .field("shared", &inner.shared)
-                .finish()
+        if let Some(shared) = self.0.as_ref() {
+            f.debug_struct("Hive").field("shared", &shared).finish()
         } else {
             f.write_str("Hive {}")
         }
@@ -669,8 +618,8 @@ impl> Debug for Hive {

 impl> PartialEq for Hive {
     fn eq(&self, other: &Hive) -> bool {
-        let self_shared = &self.0.as_ref().unwrap().shared;
-        let other_shared = &other.0.as_ref().unwrap().shared;
+        let self_shared = self.shared();
+        let other_shared = &other.shared();
         Arc::ptr_eq(self_shared, other_shared)
     }
 }
@@ -680,26 +629,26 @@ impl> Eq for Hive {}

 impl> DerefOutcomes for Hive {
     #[inline]
     fn outcomes_deref(&self) -> impl Deref>> {
-        (&self.0.as_ref().unwrap().shared).outcomes()
+        (self.shared()).outcomes()
     }

     #[inline]
     fn outcomes_deref_mut(&mut self) -> impl DerefMut>> {
-        (&self.0.as_ref().unwrap().shared).outcomes()
+        (self.shared()).outcomes()
     }
 }

 impl> Drop for Hive {
     fn drop(&mut self) {
         // if this Hive has already been turned into a Husk, its inner value will be `None`
-        if let Some(inner) = self.0.as_ref() {
+        if let Some(shared) = self.0.as_ref() {
             // reduce the referrer count
-            let _ = inner.shared.referrer_is_dropping();
+            let _ = shared.referrer_is_dropping();
             // if this Hive is the only one with a pointer to the shared data, poison it
             // to prevent any worker threads that still have access to the shared data from
             // re-spawning.
-            if inner.shared.num_referrers() == 0 {
-                inner.shared.poison();
+            if shared.num_referrers() == 0 {
+                shared.poison();
             }
         }
     }
@@ -707,14 +656,26 @@ impl> Drop for Hive {

 /// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker
 /// thread if it panics.
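Stripped of the hive-specific generics, the sentinel idiom reduces to a guard whose `Drop` runs during unwinding; a self-contained sketch:

```rust
struct PanicGuard {
    active: bool,
}

impl PanicGuard {
    // consume the guard on clean exit so Drop sees active == false
    fn cancel(mut self) {
        self.active = false;
    }
}

impl Drop for PanicGuard {
    fn drop(&mut self) {
        if self.active && std::thread::panicking() {
            // this is where the real Sentinel re-spawns the worker thread
        }
    }
}
```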
-struct Sentinel, L: LocalQueues> { +struct Sentinel +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, +{ thread_index: usize, - shared: Arc>, + shared: Arc>, active: bool, } -impl, L: LocalQueues> Sentinel { - fn new(thread_index: usize, shared: Arc>) -> Self { +impl Sentinel +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, +{ + fn new(thread_index: usize, shared: Arc>) -> Self { Self { thread_index, shared, @@ -728,7 +689,13 @@ impl, L: LocalQueues> Sentinel { } } -impl, L: LocalQueues> Drop for Sentinel { +impl Drop for Sentinel +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, +{ fn drop(&mut self) { if self.active { // if the sentinel is active, that means the thread panicked during task execution, so @@ -740,7 +707,7 @@ impl, L: LocalQueues> Drop for Sentinel, L: LocalQueues> Drop for Sentinel> Hive { #[inline] - pub(super) fn init_thread>(_: usize, _: &Shared) {} + pub(super) fn init_thread, L: LocalQueues>( + _: usize, + _: &Shared, + ) { + } } } @@ -785,7 +756,7 @@ mod affinity { num_threads: usize, affinity: C, ) -> Result { - (&self.0.as_ref().unwrap().shared).add_core_affinity(affinity.into()); + (self.shared()).add_core_affinity(affinity.into()); self.grow(num_threads) } @@ -795,7 +766,7 @@ mod affinity { /// Returns the number of new threads spun up (if any) or a `Poisoned` error if the hive /// has been poisoned. pub fn use_all_cores_with_affinity(&self) -> Result { - (&self.0.as_ref().unwrap().shared).add_core_affinity(Cores::all()); + (self.shared()).add_core_affinity(Cores::all()); self.use_all_cores() } } @@ -809,33 +780,82 @@ mod batching { impl> Hive { /// Returns the batch size for worker threads. pub fn worker_batch_size(&self) -> usize { - (&self.0.as_ref().unwrap().shared).batch_size() + (self.shared()).batch_size() } /// Sets the batch size for worker threads. This will block the current thread until all /// worker thread queues can be resized. 
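Usage is symmetric with the getter above (requires the `batching` feature; the blocking behavior is as documented):

```rust
hive.set_worker_batch_size(16); // blocks while worker queues are resized
assert_eq!(hive.worker_batch_size(), 16);
```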
pub fn set_worker_batch_size(&self, batch_size: usize) { - (&self.0.as_ref().unwrap().shared).set_batch_size(batch_size); + (self.shared()).set_batch_size(batch_size); } } } +struct HiveTaskContext<'a, W, Q, G, L> +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, +{ + thread_index: usize, + shared: &'a Arc>, + outcome_tx: Option<&'a OutcomeSender>, +} + +impl<'a, W, Q, G, L> TaskContext for HiveTaskContext<'a, W, Q, G, L> +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, +{ + fn cancel_tasks(&self) -> bool { + self.shared.is_suspended() + } + + fn submit_task(&self, input: W::Input) -> TaskId { + self.shared + .send_one_local(input, self.outcome_tx, self.thread_index) + } +} + +impl<'a, W, Q, G, L> fmt::Debug for HiveTaskContext<'a, W, Q, G, L> +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HiveTaskContext").finish() + } +} + #[cfg(not(feature = "retry"))] mod no_retry { - use crate::bee::{Queen, Worker}; - use crate::hive::{Hive, LocalQueue, Outcome, Shared, Task}; + use super::HiveTaskContext; + use crate::bee::{Context, Queen, Worker}; + use crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, Shared, Task}; + use std::sync::Arc; impl> Hive { - #[inline] - pub(super) fn execute>( + pub(super) fn execute, L: LocalQueues>( task: Task, - _thread_index: usize, + thread_index: usize, worker: &mut W, - shared: &Shared, + shared: &Arc>, ) { - let (input, ctx, outcome_tx) = task.into_parts(); + let (task_id, input, outcome_tx) = task.into_parts(); + let task_ctx = HiveTaskContext { + thread_index, + shared, + outcome_tx: outcome_tx.as_ref(), + }; + let ctx = Context::new(task_id, Some(Box::new(&task_ctx))); let result = worker.apply(input, &ctx); - let outcome = Outcome::from_worker_result(result, ctx.task_id()); + let subtask_ids = ctx.into_subtask_ids(); + let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); shared.send_or_store_outcome(outcome, outcome_tx); } } @@ -843,25 +863,37 @@ mod no_retry { #[cfg(feature = "retry")] mod retry { - use crate::bee::{ApplyError, Queen, Worker}; - use crate::hive::{Hive, LocalQueues, Outcome, Shared, Task}; + use super::HiveTaskContext; + use crate::bee::{ApplyError, Context, Queen, Worker}; + use crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, Shared, Task}; + use std::sync::Arc; impl> Hive { - #[inline] - pub(super) fn execute>( + pub(super) fn execute, L: LocalQueues>( task: Task, thread_index: usize, worker: &mut W, - shared: &Shared, + shared: &Arc>, ) { - let (input, mut ctx, outcome_tx) = task.into_parts(); - match worker.apply(input, &ctx) { - Err(ApplyError::Retryable { input, .. }) if shared.can_retry(&ctx) => { - ctx.inc_attempt(); - shared.queue_retry(thread_index, input, ctx, outcome_tx); + let (task_id, input, attempt, outcome_tx) = task.into_parts(); + let task_ctx = HiveTaskContext { + thread_index, + shared, + outcome_tx: outcome_tx.as_ref(), + }; + let ctx = Context::new(task_id, attempt, Some(Box::new(&task_ctx))); + // execute the task until it succeeds or we reach maximum retries - this should + // be the only place where a panic can occur + let result = worker.apply(input, &ctx); + let subtask_ids = ctx.into_subtask_ids(); + match result { + Err(ApplyError::Retryable { input, .. 
}) + if subtask_ids.is_none() && shared.can_retry(attempt) => + { + shared.send_retry(task_id, input, outcome_tx, attempt + 1, thread_index); } result => { - let outcome = Outcome::from_worker_result(result, ctx.task_id()); + let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); shared.send_or_store_outcome(outcome, outcome_tx); } } @@ -909,7 +941,7 @@ mod tests { assert_eq!(hive.max_workers(), 4); assert_eq!(hive.alive_workers(), 4); // poison hive using private method - hive.0.as_ref().unwrap().shared.poison(); + hive.0.as_ref().unwrap().poison(); // attempt to spawn a new task assert!(matches!(hive.grow(1), Err(Poisoned))); // make sure the worker count wasn't increased @@ -923,11 +955,11 @@ mod tests { .num_threads(4) .build_with(Caller::of(|i: usize| i * 2)); // poison hive using private method - hive.0.as_ref().unwrap().shared.poison(); + hive.0.as_ref().unwrap().poison(); // submit a task, check that it comes back unprocessed let (tx, rx) = outcome_channel(); let sent_input = 1; - let sent_task_id = hive.apply_send(sent_input, tx.clone()); + let sent_task_id = hive.apply_send(sent_input, &tx); let outcome = rx.recv().unwrap(); match outcome { Outcome::Unprocessed { input, task_id } => { @@ -944,12 +976,12 @@ mod tests { .num_threads(4) .build_with(Caller::of(|i: usize| i * 2)); // poison hive using private method - hive.0.as_ref().unwrap().shared.poison(); + hive.0.as_ref().unwrap().poison(); // submit a task, check that it comes back unprocessed let (tx, rx) = outcome_channel(); let inputs = 0..10; let task_ids: HashMap = hive - .swarm_send(inputs.clone(), tx) + .swarm_send(inputs.clone(), &tx) .into_iter() .zip(inputs) .collect(); diff --git a/src/hive/counter.rs b/src/hive/counter.rs index bb19c11..6a4741d 100644 --- a/src/hive/counter.rs +++ b/src/hive/counter.rs @@ -4,8 +4,6 @@ use crate::atomic::{Atomic, AtomicInt, AtomicU64, Ordering, Orderings}; const SEQCST_ORDERING: Orderings = Orderings { load: Ordering::SeqCst, swap: Ordering::SeqCst, - fetch_update_set: Ordering::SeqCst, - fetch_update_fetch: Ordering::SeqCst, fetch_add: Ordering::SeqCst, fetch_sub: Ordering::SeqCst, }; diff --git a/src/hive/husk.rs b/src/hive/husk.rs index 58ee534..0c12034 100644 --- a/src/hive/husk.rs +++ b/src/hive/husk.rs @@ -67,7 +67,7 @@ impl> Husk { /// This method returns a `SpawnError` if there is an error creating the new `Hive`. 
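The husk round trip mirrors the test later in this patch; a sketch:

```rust
// drain a suspended/poisoned hive into a Husk, then restart it
let husk = hive.try_into_husk().expect("other Hive clones still exist");
let (tx, rx) = outcome_channel();
let (hive2, task_ids) = husk.into_hive_swarm_send_unprocessed(&tx);
hive2.grow(8).expect("error spawning threads");
hive2.join();
```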
pub fn into_hive_swarm_send_unprocessed( mut self, - tx: OutcomeSender, + tx: &OutcomeSender, ) -> (Hive, Vec) { let unprocessed: Vec<_> = self .remove_all_unprocessed() @@ -182,7 +182,7 @@ mod tests { hive1.suspend(); let husk1 = hive1.try_into_husk().unwrap(); let (tx, rx) = outcome_channel(); - let (hive2, task_ids) = husk1.into_hive_swarm_send_unprocessed(tx); + let (hive2, task_ids) = husk1.into_hive_swarm_send_unprocessed(&tx); // now spin up worker threads to process the tasks hive2.grow(8).expect("error spawning threads"); hive2.join(); diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 8332c64..a912049 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -404,9 +404,9 @@ pub mod prelude { use self::counter::DualCounter; use self::gate::{Gate, PhasedGate}; use self::outcome::{DerefOutcomes, OwnedOutcomes}; -use self::task::LocalQueuesImpl; +use self::task::{ChannelGlobalQueue, ChannelLocalQueues}; use crate::atomic::{AtomicAny, AtomicBool, AtomicOption, AtomicUsize}; -use crate::bee::{Context, Queen, TaskId, Worker}; +use crate::bee::{Queen, TaskId, Worker}; use parking_lot::Mutex; use std::collections::HashMap; use std::io::Error as SpawnError; @@ -423,26 +423,23 @@ type U64 = AtomicOption; /// A pool of worker threads that each execute the same function. /// /// See the [module documentation](crate::hive) for details. -pub struct Hive>(Option>); - -/// A `Hive`'s inner state. Wraps a) the `Hive`'s reference to the `Shared` data (which is shared -/// with the worker threads) and b) the `Sender>`, which is the sending end of the channel -/// used to send tasks to the worker threads. -struct HiveInner> { - task_tx: TaskSender, - shared: Arc>>, -} +pub struct Hive>( + Option, ChannelLocalQueues>>>, +); /// Type alias for the input task channel sender -type TaskSender = std::sync::mpsc::Sender>; +type TaskSender = crossbeam_channel::Sender>; /// Type alias for the input task channel receiver -type TaskReceiver = std::sync::mpsc::Receiver>; +type TaskReceiver = crossbeam_channel::Receiver>; /// Internal representation of a task to be processed by a `Hive`. +#[derive(Debug)] struct Task { + id: TaskId, input: W::Input, - ctx: Context, outcome_tx: Option>, + #[cfg(feature = "retry")] + attempt: u32, } /// Core configuration parameters that are set by a `Builder`, used in a `Hive`, and preserved in a @@ -471,13 +468,15 @@ struct Config { } /// Data shared by all worker threads in a `Hive`. 
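The new `&OutcomeSender` signatures let the hive clone the sender only when a task is actually created, instead of once per call at every API layer. Separately, the task channel aliases above now point at `crossbeam_channel`, whose `Sender` is both `Clone` and `Sync`, so a single sender can be shared by reference across threads. A standalone illustration using only the crossbeam API:

```rust
let (tx, rx) = crossbeam_channel::unbounded::<u32>();
std::thread::scope(|s| {
    for i in 0..4 {
        let tx = &tx; // shared by reference, no clone per thread
        s.spawn(move || tx.send(i).unwrap());
    }
});
assert_eq!(rx.try_iter().count(), 4);
```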
-struct Shared, L: LocalQueues> {
+struct Shared, G: GlobalQueue, L: LocalQueues> {
     /// core configuration parameters
     config: Config,
+    /// global task queue used by the `Hive` to send tasks to the worker threads
+    global_queue: G,
+    /// local queues used by worker threads to manage tasks
+    local_queues: L,
     /// the `Queen` used to create new workers
     queen: Mutex,
-    /// receiver for the channel used by the `Hive` to send tasks to the worker threads
-    task_rx: Mutex>,
     /// The results of spawning each worker
     spawn_results: Mutex, SpawnError>>>,
     /// allows for 2^48 queued tasks and 2^16 active tasks
@@ -494,7 +493,7 @@ struct Shared, L: LocalQueues> {
     poisoned: AtomicBool,
     /// whether the hive is suspended - if true, active tasks may complete and new tasks may be
     /// queued, but new tasks will not be processed
-    suspended: Arc,
+    suspended: AtomicBool,
     /// gate used by worker threads to wait until the hive is resumed
     resume_gate: Gate,
     /// gate used by client threads to wait until all tasks have completed
@@ -502,14 +501,94 @@ struct Shared, L: LocalQueues> {
     /// outcomes stored in the hive
     /// TODO: switch to using thread-local outcome maps that need to be gathered when extracting
     outcomes: Mutex>>,
-    /// local queues used by worker threads to manage tasks
-    local_queues: L,
 }

-/// Trait that provides access to thread-specific queues for managing tasks. Ideally, these queues
-/// would be managed in a global thread-local data structure, but it has a generic type that
-/// requires it to be stored within the Hive's shared data.
-trait LocalQueues: Sized + Default + Send + Sync + 'static {}
+#[derive(thiserror::Error, Debug)]
+pub enum GlobalPopError {
+    #[error("Task queue is closed")]
+    Closed,
+    #[error("The hive has been poisoned")]
+    Poisoned,
+}
+
+/// Trait that provides access to a global queue for receiving tasks.
+trait GlobalQueue: Sized + Send + Sync + 'static {
+    /// Tries to add a task to the global queue.
+    ///
+    /// Returns an error if the queue is disconnected.
+    fn try_push(&self, task: Task) -> Result<(), Task>;
+
+    /// Tries to take a task from the global queue.
+    ///
+    /// Returns `None` if a task is not available, where each implementation may have a different
+    /// definition of "available".
+    ///
+    /// Returns an error if the queue is disconnected.
+    fn try_pop(&self) -> Option, GlobalPopError>>;
+
+    /// Drains all tasks from the global queue and returns them as a `Vec`.
+    fn drain(&self) -> Vec>;
+
+    /// Closes this `GlobalQueue` so no more tasks may be pushed.
+    fn close(&self);
+}
+
+/// Trait that provides access to thread-specific queues for managing tasks.
+///
+/// Ideally, these queues would be managed in a global thread-local data structure, but since tasks
+/// are `Worker`-specific, each `Hive` must have its own set of queues stored within the Hive's
+/// shared data.
+trait LocalQueues>: Sized + Default + Send + Sync + 'static {
+    /// Initializes the local queues for the given range of worker thread indices.
+    fn init_for_threads>(
+        &self,
+        start_index: usize,
+        end_index: usize,
+        shared: &Shared,
+    );
+
+    /// Changes the size of the local queues to `new_size`.
+    #[cfg(feature = "batching")]
+    fn resize>(
+        &self,
+        start_index: usize,
+        end_index: usize,
+        new_size: usize,
+        shared: &Shared,
+    );
+
+    /// Attempts to add a task to the local queue if space is available, otherwise adds it to the
+    /// global queue.
If adding to the global queue fails, the task is abandoned (converted to an + /// `Unprocessed` outcome and sent to the outcome channel or stored in the hive). + fn push>( + &self, + task: Task, + thread_index: usize, + shared: &Shared, + ); + + /// Attempts to remove a task from the local queue for the given worker thread index. + /// + /// Returns `None` if there is no task immediately available. + fn try_pop>( + &self, + thread_index: usize, + shared: &Shared, + ) -> Option>; + + /// Drains all tasks from all local queues and returns them as an iterator. + fn drain(&self) -> Vec>; + + /// Attempts to add `task` to the local retry queue. Returns the earliest `Instant` at which it + /// might be retried. + #[cfg(feature = "retry")] + fn retry>( + &self, + task: Task, + thread_index: usize, + shared: &Shared, + ) -> Option; +} #[cfg(test)] mod tests { @@ -568,7 +647,7 @@ mod tests { let hive = thunk_hive::(0); // check that with 0 threads no tasks are scheduled let (tx, rx) = super::outcome_channel(); - let _ = hive.apply_send(Thunk::of(|| 0), tx); + let _ = hive.apply_send(Thunk::of(|| 0), &tx); thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().0, 1); assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); @@ -637,7 +716,11 @@ mod tests { type Output = u8; type Error = (); - fn apply_ref(&mut self, input: &Self::Input, ctx: &Context) -> RefWorkerResult { + fn apply_ref( + &mut self, + input: &Self::Input, + ctx: &Context, + ) -> RefWorkerResult { for _ in 0..3 { thread::sleep(Duration::from_secs(1)); if ctx.is_cancelled() { @@ -702,7 +785,7 @@ mod tests { let (tx, _) = super::outcome_channel(); // Panic all the existing threads. for _ in 0..TEST_TASKS { - hive.apply_send(Thunk::of(|| panic!("intentional panic")), tx.clone()); + hive.apply_send(Thunk::of(|| panic!("intentional panic")), &tx); } hive.join(); // Ensure that none of the threads have panicked @@ -721,7 +804,7 @@ mod tests { let (tx, rx) = super::outcome_channel(); // Panic all the existing threads. for i in 0..TEST_TASKS { - hive.apply_send(i as u8, tx.clone()); + hive.apply_send(i as u8, &tx); } hive.join(); // Ensure that none of the threads have panicked @@ -1042,7 +1125,7 @@ mod tests { i }) }), - tx, + &tx, ); let (mut outcome_task_ids, values): (Vec, Vec) = rx .iter() @@ -1130,7 +1213,7 @@ mod tests { i }) }), - tx, + &tx, ); let (mut outcome_task_ids, values): (Vec, Vec) = rx .iter() @@ -1202,7 +1285,7 @@ mod tests { .num_threads(4) .build_with(Caller::of(|i| i * i)); let (tx, rx) = super::outcome_channel(); - let (mut task_ids, state) = hive.scan_send(0..10, tx, 0, |acc, i| { + let (mut task_ids, state) = hive.scan_send(0..10, &tx, 0, |acc, i| { *acc += i; *acc }); @@ -1237,7 +1320,7 @@ mod tests { .num_threads(4) .build_with(Caller::of(|i| i * i)); let (tx, rx) = super::outcome_channel(); - let (results, state) = hive.try_scan_send(0..10, tx, 0, |acc, i| { + let (results, state) = hive.try_scan_send(0..10, &tx, 0, |acc, i| { *acc += i; Ok::<_, String>(*acc) }); @@ -1275,7 +1358,7 @@ mod tests { .build_with(OnceCaller::of(|i: i32| Ok::<_, String>(i * i))); let (tx, _) = super::outcome_channel(); let _ = hive - .try_scan_send(0..10, tx, 0, |_, _| Err("fail")) + .try_scan_send(0..10, &tx, 0, |_, _| Err("fail")) .0 .into_iter() .map(Result::unwrap) @@ -1608,7 +1691,7 @@ mod tests { // return results to your own channel... 
let (tx, rx) = crate::hive::outcome_channel(); - let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::of(move || i * i)), tx); + let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::of(move || i * i)), &tx); let outputs: Vec<_> = rx.select_unordered_outputs(task_ids).collect(); assert_eq!(285, outputs.into_iter().sum()); @@ -1652,7 +1735,7 @@ mod tests { type Output = String; type Error = io::Error; - fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { self.write_char(input).map_err(|error| ApplyError::Fatal { input: Some(input), error, @@ -1816,7 +1899,7 @@ mod batching_tests { thread::sleep(Duration::from_millis(100)); thread::current().id() }), - tx.clone(), + &tx, ); thread::sleep(Duration::from_millis(100)); task_id @@ -1830,7 +1913,7 @@ mod batching_tests { thread::current().id() }) }), - tx.clone(), + &tx, ); init_task_ids.into_iter().chain(rest_task_ids).collect() } @@ -1932,7 +2015,7 @@ mod retry_tests { use crate::hive::{Builder, Outcome, OutcomeIteratorExt}; use std::time::{Duration, SystemTime}; - fn echo_time(i: usize, ctx: &Context) -> Result> { + fn echo_time(i: usize, ctx: &Context) -> Result> { let attempt = ctx.attempt(); if attempt == 3 { Ok("Success".into()) @@ -1960,7 +2043,10 @@ mod retry_tests { #[test] fn test_retries_fail() { - fn sometimes_fail(i: usize, _: &Context) -> Result> { + fn sometimes_fail( + i: usize, + _: &Context, + ) -> Result> { match i % 3 { 0 => Ok("Success".into()), 1 => Err(ApplyError::Retryable { diff --git a/src/hive/outcome/outcome.rs b/src/hive/outcome/outcome.rs index be7c179..ff926a0 100644 --- a/src/hive/outcome/outcome.rs +++ b/src/hive/outcome/outcome.rs @@ -5,7 +5,9 @@ use std::fmt::Debug; /// The possible outcomes of a task execution. /// -/// Each outcome includes the task ID of the task that produced it. +/// Each outcome includes the task ID of the task that produced it. Tasks that submitted +/// subtasks (via [`crate::bee::Context::submit_task`]) produce `Outcome` variants that have +/// `subtask_ids`. /// /// Note that `Outcome`s can only be compared or ordered with other `Outcome`s produced by the same /// `Hive`, because comparison/ordering is completely based on the task ID. @@ -13,6 +15,13 @@ use std::fmt::Debug; pub enum Outcome { /// The task was executed successfully. Success { value: W::Output, task_id: TaskId }, + /// The task was executed successfully, and it also submitted one or more subtask_ids to the + /// `Hive`. + SuccessWithSubtasks { + value: W::Output, + task_id: TaskId, + subtask_ids: Vec, + }, /// The task failed with an error that was not retryable. The input value that caused the /// failure is provided if possible. Failure { @@ -20,9 +29,24 @@ pub enum Outcome { error: W::Error, task_id: TaskId, }, + /// The task failed with an error that was not retryable, but it submitted one or more subtask_ids + /// before failing. The input value that caused the failure is provided if possible. + FailureWithSubtasks { + input: Option, + error: W::Error, + task_id: TaskId, + subtask_ids: Vec, + }, /// The task was not executed before the Hive was dropped, or processing of the task was /// interrupted (e.g., by `suspend`ing the `Hive`). Unprocessed { input: W::Input, task_id: TaskId }, + /// The task was not executed before the Hive was dropped, or processing of the task was + /// interrupted (e.g., by `suspend`ing the `Hive`), but it first submitted one or more subtask_ids. 
+ UnprocessedWithSubtasks { + input: W::Input, + task_id: TaskId, + subtask_ids: Vec, + }, /// The task with the given task_id was not found in the `Hive` or iterator from which it was /// being requested. Missing { task_id: TaskId }, @@ -32,6 +56,14 @@ pub enum Outcome { payload: Panic, task_id: TaskId, }, + /// The task panicked, but it submitted one or more subtask_ids before panicking. The input value + /// that caused the panic is provided if possible. + PanicWithSubtasks { + input: Option, + payload: Panic, + task_id: TaskId, + subtask_ids: Vec, + }, /// The task failed after retrying the maximum number of times. #[cfg(feature = "retry")] MaxRetriesAttempted { @@ -42,11 +74,28 @@ pub enum Outcome { } impl Outcome { - /// Converts a worker `result` into an `Outcome` with the given task_id. - pub(in crate::hive) fn from_worker_result(result: WorkerResult, task_id: TaskId) -> Self { - match result { - Ok(value) => Self::Success { task_id, value }, - Err(ApplyError::Retryable { input, error }) => { + /// Converts a worker `result` into an `Outcome` with the given task_id and optional subtask ids. + pub(in crate::hive) fn from_worker_result( + result: WorkerResult, + task_id: TaskId, + subtask_ids: Option>, + ) -> Self { + match (result, subtask_ids) { + (Ok(value), Some(subtask_ids)) => Self::SuccessWithSubtasks { + value, + task_id, + subtask_ids, + }, + (Ok(value), None) => Self::Success { value, task_id }, + (Err(ApplyError::Retryable { input, error, .. }), Some(subtask_ids)) => { + Self::FailureWithSubtasks { + input: Some(input), + error, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Retryable { input, error }), None) => { #[cfg(feature = "retry")] { Self::MaxRetriesAttempted { @@ -64,13 +113,36 @@ impl Outcome { } } } - Err(ApplyError::Fatal { input, error }) => Self::Failure { + (Err(ApplyError::Fatal { input, error }), Some(subtask_ids)) => { + Self::FailureWithSubtasks { + input, + error, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Fatal { input, error }), None) => Self::Failure { input, error, task_id, }, - Err(ApplyError::Cancelled { input }) => Self::Unprocessed { input, task_id }, - Err(ApplyError::Panic { input, payload }) => Self::Panic { + (Err(ApplyError::Cancelled { input }), Some(subtask_ids)) => { + Self::UnprocessedWithSubtasks { + input, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Cancelled { input }), None) => Self::Unprocessed { input, task_id }, + (Err(ApplyError::Panic { input, payload }), Some(subtask_ids)) => { + Self::PanicWithSubtasks { + input, + payload, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Panic { input, payload }), None) => Self::Panic { input, payload, task_id, @@ -102,15 +174,31 @@ impl Outcome { pub fn task_id(&self) -> &TaskId { match self { Self::Success { task_id, .. } + | Self::SuccessWithSubtasks { task_id, .. } | Self::Failure { task_id, .. } + | Self::FailureWithSubtasks { task_id, .. } | Self::Unprocessed { task_id, .. } + | Self::UnprocessedWithSubtasks { task_id, .. } | Self::Missing { task_id } - | Self::Panic { task_id, .. } => task_id, + | Self::Panic { task_id, .. } + | Self::PanicWithSubtasks { task_id, .. } => task_id, #[cfg(feature = "retry")] Self::MaxRetriesAttempted { task_id, .. } => task_id, } } + /// Returns the IDs of the tasks submitted by the task that produced this outcome, or `None` + /// if the task did not submit any subtasks. + pub fn subtask_ids(&self) -> Option<&Vec> { + match self { + Self::SuccessWithSubtasks { subtask_ids, .. } + | Self::FailureWithSubtasks { subtask_ids, .. 
} + | Self::UnprocessedWithSubtasks { subtask_ids, .. } + | Self::PanicWithSubtasks { subtask_ids, .. } => Some(subtask_ids), + _ => None, + } + } + /// Consumes this `Outcome` and returns the value if it is a `Success`, otherwise panics. pub fn unwrap(self) -> W::Output { self.success().expect("not a Success outcome") @@ -127,11 +215,14 @@ impl Outcome { /// Consumes this `Outcome` and returns the input value if available, otherwise `None`. pub fn try_into_input(self) -> Option { match self { - Self::Success { .. } => None, - Self::Failure { input, .. } => input, - Self::Unprocessed { input, .. } => Some(input), - Self::Missing { .. } => None, - Self::Panic { input, .. } => input, + Self::Failure { input, .. } + | Self::FailureWithSubtasks { input, .. } + | Self::Panic { input, .. } + | Self::PanicWithSubtasks { input, .. } => input, + Self::Unprocessed { input, .. } | Self::UnprocessedWithSubtasks { input, .. } => { + Some(input) + } + Self::Success { .. } | Self::SuccessWithSubtasks { .. } | Self::Missing { .. } => None, #[cfg(feature = "retry")] Self::MaxRetriesAttempted { input, .. } => Some(input), } @@ -143,11 +234,15 @@ impl Outcome { /// * Otherwise returns `None`. pub fn try_into_error(self) -> Option { match self { - Self::Success { .. } => None, - Self::Failure { error, .. } => Some(error), - Self::Unprocessed { .. } => None, - Self::Missing { .. } => None, - Self::Panic { payload, .. } => payload.resume(), + Self::Failure { error, .. } | Self::FailureWithSubtasks { error, .. } => Some(error), + Self::Panic { payload, .. } | Self::PanicWithSubtasks { payload, .. } => { + payload.resume() + } + Self::Success { .. } + | Self::SuccessWithSubtasks { .. } + | Self::Unprocessed { .. } + | Self::UnprocessedWithSubtasks { .. } + | Self::Missing { .. } => None, #[cfg(feature = "retry")] Self::MaxRetriesAttempted { error, .. } => Some(error), } @@ -157,11 +252,27 @@ impl Outcome { impl PartialEq for Outcome { fn eq(&self, other: &Self) -> bool { match (self, other) { - (Self::Success { task_id: a, .. }, Self::Success { task_id: b, .. }) => a == b, - (Self::Failure { task_id: a, .. }, Self::Failure { task_id: b, .. }) => a == b, - (Self::Unprocessed { task_id: a, .. }, Self::Unprocessed { task_id: b, .. }) => a == b, - (Self::Missing { task_id: a }, Self::Missing { task_id: b }) => a == b, - (Self::Panic { task_id: a, .. }, Self::Panic { task_id: b, .. }) => a == b, + (Self::Success { task_id: a, .. }, Self::Success { task_id: b, .. }) + | ( + Self::SuccessWithSubtasks { task_id: a, .. }, + Self::SuccessWithSubtasks { task_id: b, .. }, + ) + | (Self::Failure { task_id: a, .. }, Self::Failure { task_id: b, .. }) + | ( + Self::FailureWithSubtasks { task_id: a, .. }, + Self::FailureWithSubtasks { task_id: b, .. }, + ) + | (Self::Unprocessed { task_id: a, .. }, Self::Unprocessed { task_id: b, .. }) + | ( + Self::UnprocessedWithSubtasks { task_id: a, .. }, + Self::UnprocessedWithSubtasks { task_id: b, .. }, + ) + | (Self::Missing { task_id: a }, Self::Missing { task_id: b }) + | (Self::Panic { task_id: a, .. }, Self::Panic { task_id: b, .. }) + | ( + Self::PanicWithSubtasks { task_id: a, .. }, + Self::PanicWithSubtasks { task_id: b, .. }, + ) => a == b, #[cfg(feature = "retry")] ( Self::MaxRetriesAttempted { task_id: a, .. 
}, diff --git a/src/hive/outcome/store.rs b/src/hive/outcome/store.rs index 0a22b7c..b38835d 100644 --- a/src/hive/outcome/store.rs +++ b/src/hive/outcome/store.rs @@ -354,7 +354,7 @@ mod tests { type Output = u8; type Error = (); - fn apply(&mut self, i: Self::Input, _: &Context) -> WorkerResult { + fn apply(&mut self, i: Self::Input, _: &Context) -> WorkerResult { Ok(i) } } diff --git a/src/hive/shared.rs b/src/hive/shared.rs index cc02f9a..91cdc72 100644 --- a/src/hive/shared.rs +++ b/src/hive/shared.rs @@ -1,23 +1,31 @@ -use super::{Config, LocalQueues, Outcome, OutcomeSender, Shared, SpawnError, Task, TaskReceiver}; +use super::task::ChannelGlobalQueue; +use super::{ + Config, GlobalQueue, Husk, LocalQueues, Outcome, OutcomeSender, Shared, SpawnError, Task, +}; use crate::atomic::{Atomic, AtomicInt, AtomicUsize}; -use crate::bee::{Context, Queen, TaskId, Worker}; +use crate::bee::{Queen, TaskId, Worker}; use crate::channel::SenderExt; use parking_lot::Mutex; use std::collections::HashMap; use std::ops::DerefMut; -use std::sync::mpsc::RecvTimeoutError; use std::thread::{Builder, JoinHandle}; -use std::time::Duration; use std::{fmt, iter, mem}; -impl, L: LocalQueues> Shared { +impl Shared +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, +{ /// Creates a new `Shared` instance with the given configuration, queen, and task receiver, /// and all other fields set to their default values. - pub fn new(config: Config, queen: Q, task_rx: TaskReceiver) -> Self { + pub fn new(config: Config, global_queue: G, queen: Q) -> Self { Shared { config, + global_queue, queen: Mutex::new(queen), - task_rx: Mutex::new(task_rx), + local_queues: Default::default(), spawn_results: Default::default(), num_tasks: Default::default(), next_task_id: Default::default(), @@ -28,7 +36,6 @@ impl, L: LocalQueues> Shared { resume_gate: Default::default(), join_gate: Default::default(), outcomes: Default::default(), - local_queues: Default::default(), } } @@ -84,10 +91,8 @@ impl, L: LocalQueues> Shared { assert_eq!(spawn_results.len(), start_index); let end_index = start_index + num_threads; // if worker threads need a local queue, initialize them before spawning - #[cfg(feature = "batching")] - self.init_local_queues(start_index, end_index); - #[cfg(feature = "retry")] - self.init_retry_queues(start_index, end_index); + self.local_queues + .init_for_threads(start_index, end_index, self); // spawn the worker threads and return the results let results: Vec<_> = (start_index..end_index).map(f).collect(); spawn_results.reserve(num_threads); @@ -146,29 +151,75 @@ impl, L: LocalQueues> Shared { /// Increments the number of queued tasks. Returns a new `Task` with the provided input and /// `outcome_tx` and the next ID. - pub fn prepare_task(&self, input: W::Input, outcome_tx: Option>) -> Task { + fn prepare_task(&self, input: W::Input, outcome_tx: Option<&OutcomeSender>) -> Task { self.num_tasks .increment_left(1) .expect("overflowed queued task counter"); let task_id = self.next_task_id.add(1); - let ctx = Context::new(task_id, self.suspended.clone()); - Task::new(input, ctx, outcome_tx) + Task::new(task_id, input, outcome_tx.cloned()) + } + + /// Adds `task` to the global queue if possible, otherwise abandons it - converts it to an + /// `Unprocessed` outcome and sends it to the outcome channel or stores it in the hive. 
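Downstream code can match the new variants directly or use the `subtask_ids` accessor added above; a sketch, assuming the worker's output implements `Debug`:

```rust
match outcome {
    Outcome::Success { value, task_id } => {
        println!("task {task_id} succeeded: {value:?}");
    }
    Outcome::SuccessWithSubtasks { value, task_id, subtask_ids } => {
        println!("task {task_id} -> {value:?} (+{} subtasks)", subtask_ids.len());
    }
    other => {
        // every *WithSubtasks variant also reports its IDs via the accessor
        if let Some(ids) = other.subtask_ids() {
            eprintln!("task {} ended with {} subtasks", other.task_id(), ids.len());
        }
    }
}
```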
+ pub fn push_global(&self, task: Task) { + // try to send the task to the hive; if the hive is poisoned or if sending fails, convert + // the task into an `Unprocessed` outcome and try to send it to the outcome channel; if + // that fails, store the outcome in the hive + if let Some(abandoned_task) = if self.is_poisoned() { + Some(task) + } else { + self.global_queue.try_push(task).err() + } { + self.abandon_task(abandoned_task); + } } - /// Increments the number of queued tasks by the number of provided inputs. Returns an iterator - /// over `Task`s created from the provided inputs, `outcome_tx`s, and sequential task_ids. - pub fn prepare_batch<'a, T: Iterator + 'a>( - &'a self, - min_size: usize, + /// Creates a new `Task` for the given input and outcome channel, and adds it to the global + /// queue. + pub fn send_one_global( + &self, + input: W::Input, + outcome_tx: Option<&OutcomeSender>, + ) -> TaskId { + let task = self.prepare_task(input, outcome_tx); + let task_id = task.id(); + self.push_global(task); + task_id + } + + /// Creates a new `Task` for the given input and outcome channel, and attempts to add it to + /// the local queue for the specified `thread_index`. Falls back to adding it to the global + /// queue. + pub fn send_one_local( + &self, + input: W::Input, + outcome_tx: Option<&OutcomeSender>, + thread_index: usize, + ) -> TaskId { + let task = self.prepare_task(input, outcome_tx); + let task_id = task.id(); + self.local_queues.push(task, thread_index, &self); + task_id + } + + /// Creates a new `Task` for each input in the given batch and sends them to the global queue. + pub fn send_batch_global( + &self, inputs: T, - outcome_tx: Option>, - ) -> impl Iterator> + 'a { + outcome_tx: Option<&OutcomeSender>, + ) -> Vec + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + let iter = inputs.into_iter(); + let (min_size, _) = iter.size_hint(); self.num_tasks .increment_left(min_size as u64) .expect("overflowed queued task counter"); let task_id_start = self.next_task_id.add(min_size); let task_id_end = task_id_start + min_size; - inputs + let tasks = iter .map(Some) .chain(iter::repeat_with(|| None)) .zip( @@ -177,16 +228,72 @@ impl, L: LocalQueues> Shared { .chain(iter::repeat_with(|| None)), ) .map_while(move |pair| match pair { - (Some(input), Some(task_id)) => Some(Task { - input, - ctx: Context::new(task_id, self.suspended.clone()), - //attempt: 0, - outcome_tx: outcome_tx.clone(), - }), - (Some(input), None) => Some(self.prepare_task(input, outcome_tx.clone())), + (Some(input), Some(task_id)) => { + Some(Task::new(task_id, input, outcome_tx.cloned())) + } + (Some(input), None) => Some(self.prepare_task(input, outcome_tx)), (None, Some(_)) => panic!("batch contained fewer than {min_size} items"), (None, None) => None, - }) + }); + if !self.is_poisoned() { + tasks + .map(|task| { + let task_id = task.id(); + // try to send the task to the hive; if sending fails, convert the task into an + // `Unprocessed` outcome and try to send it to the outcome channel; if that + // fails, store the outcome in the hive + if let Err(task) = self.global_queue.try_push(task) { + self.abandon_task(task); + } + task_id + }) + .collect() + } else { + // if the hive is poisoned, convert all tasks into `Unprocessed` outcomes and try to + // send them to their outcome channels or store them in the hive + self.abandon_batch(tasks) + } + } + + /// Returns the next available `Task`. 
If there is a task in any local queue, it is returned,
+    /// otherwise a task is requested from the global queue.
+    ///
+    /// If the hive is suspended, the calling thread blocks until the `Hive` is resumed.
+    /// The calling thread also blocks until a task becomes available.
+    ///
+    /// Returns `None` if the hive is poisoned, or if the local queues are empty and the
+    /// global queue has been disconnected.
+    pub fn get_next_task(&self, thread_index: usize) -> Option> {
+        loop {
+            // block while the hive is suspended
+            self.wait_on_resume();
+            // stop iteration if the hive is poisoned
+            if self.is_poisoned() {
+                return None;
+            }
+            // try to get a task from the local queues
+            if let Some(task) = self.local_queues.try_pop(thread_index, &self) {
+                break Ok(task);
+            }
+            // fall back to requesting a task from the global queue
+            if let Some(result) = self.global_queue.try_pop() {
+                break result;
+            }
+        }
+        // if a task was successfully received, decrement the queued counter and increment the
+        // active counter
+        .map(|task| match self.num_tasks.transfer(1) {
+            Ok(_) => Some(task),
+            Err(_) => {
+                // the hive is in a corrupted state - abandon this task and then poison the hive
+                // so it can't be used anymore
+                self.abandon_task(task);
+                self.poison();
+                None
+            }
+        })
+        .ok()
+        .flatten()
     }
 
     /// Sends an outcome to `outcome_tx`, or stores it in the `Hive` shared data if there is no
@@ -337,6 +444,11 @@ impl, L: LocalQueues> Shared {
         self.suspended.get()
     }
 
+    #[inline]
+    pub fn wait_on_resume(&self) {
+        self.resume_gate.wait_while(|| self.is_suspended());
+    }
+
     /// Returns a mutable reference to the retained task outcomes.
     pub fn outcomes(&self) -> impl DerefMut>> + '_ {
         self.outcomes.lock()
@@ -367,9 +479,36 @@ impl, L: LocalQueues> Shared {
             .map(|task_id| outcomes.remove(&task_id).unwrap())
             .collect()
     }
+
+    /// Drains all queued tasks, converts them into `Outcome::Unprocessed` outcomes, and tries
+    /// to send them or (if the task does not have a sender, or if the send fails) stores them
+    /// in the `outcomes` map.
+    fn drain_tasks_into_unprocessed(&self) {
+        self.abandon_batch(self.global_queue.drain().into_iter());
+        self.abandon_batch(self.local_queues.drain().into_iter());
+    }
+
+    /// Consumes this `Shared` and returns a `Husk` containing the `Queen`, panic count, stored
+    /// outcomes, and all configuration information necessary to create a new `Hive`. Any queued
+    /// tasks are converted into `Outcome::Unprocessed` outcomes and either sent to the task's
+    /// sender or (if there is no sender, or the send fails) stored in the `outcomes` map.
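The scheduling order that `get_next_task` implements reduces to: local queues first, then the global queue. A minimal sketch of that order using toy `std` types (`WorkerQueues` and `next_task` here are illustrative stand-ins, not the crate's API; the suspend gate, poison check, and task-counter transfer are omitted):

use std::collections::VecDeque;
use std::sync::mpsc;

// Toy stand-ins for the hive's queues (illustrative only).
struct WorkerQueues {
    local: VecDeque<u32>,        // thread-local batch/retry queues
    global: mpsc::Receiver<u32>, // shared global queue
}

impl WorkerQueues {
    // Same order as `get_next_task`: drain the local queues first, then
    // fall back to the global queue; `None` means the global queue has
    // disconnected and there is nothing left to do.
    fn next_task(&mut self) -> Option<u32> {
        if let Some(task) = self.local.pop_front() {
            return Some(task);
        }
        self.global.recv().ok()
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    tx.send(2).unwrap();
    drop(tx);
    let mut queues = WorkerQueues {
        local: VecDeque::from([1]),
        global: rx,
    };
    assert_eq!(queues.next_task(), Some(1)); // the local queue wins
    assert_eq!(queues.next_task(), Some(2)); // then the global queue
    assert_eq!(queues.next_task(), None); // disconnected
}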
+ pub fn into_husk(self) -> Husk { + self.drain_tasks_into_unprocessed(); + Husk::new( + self.config.into_unsync(), + self.queen.into_inner(), + self.num_panics.into_inner(), + self.outcomes.into_inner(), + ) + } } -impl, L: LocalQueues> fmt::Debug for Shared { +impl fmt::Debug for Shared, L> +where + W: Worker, + Q: Queen, + L: LocalQueues>, +{ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let (queued, active) = self.num_tasks(); f.debug_struct("Shared") @@ -406,50 +545,18 @@ mod affinity { } } -#[inline] -fn task_recv_timeout(rx: &TaskReceiver) -> Option, NextTaskError>> { - // time to wait in between polling the retry queue and then the task receiver - const RECV_TIMEOUT: Duration = Duration::from_secs(1); - match rx.recv_timeout(RECV_TIMEOUT) { - Ok(task) => Some(Ok(task)), - Err(RecvTimeoutError::Disconnected) => Some(Err(NextTaskError::Disconnected)), - Err(RecvTimeoutError::Timeout) => None, - } -} - -#[cfg(not(feature = "batching"))] -mod no_batching { - use super::{LocalQueue, NextTaskError, Shared, Task}; - use crate::bee::{Queen, Worker}; - - impl, L: LocalQueue> Shared { - /// Tries to receive a task from the input channel. - /// - /// Returns an error if the channel has disconnected. Returns `None` if a task is not - /// received within the timeout period (currently hard-coded to 1 second). - #[inline] - pub(super) fn get_task(&self, _: usize) -> Option, NextTaskError>> { - super::task_recv_timeout(&self.task_rx.lock()) - } - } -} - #[cfg(feature = "batching")] mod batching { - use super::{NextTaskError, Shared, Task}; use crate::bee::{Queen, Worker}; - use crate::hive::LocalQueues; - use crossbeam_queue::ArrayQueue; - use std::collections::HashSet; - use std::time::Duration; - - impl, L: LocalQueues> Shared { - pub(super) fn init_local_queues(&self, start_index: usize, end_index: usize) { - let mut local_queues = self.local_queues.write(); - assert_eq!(local_queues.len(), start_index); - (start_index..end_index).for_each(|_| local_queues.push(L::new(self))); - } + use crate::hive::{GlobalQueue, LocalQueues, Shared}; + impl Shared + where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + { /// Returns the local queue batch size. 
pub fn batch_size(&self) -> usize { self.config.batch_size.get().unwrap_or_default() @@ -476,305 +583,50 @@ mod batching { if num_threads == 0 { return prev_batch_size; } - // keep track of which queues need to be resized - // TODO: this method could cause a hang if one of the worker threads is stuck - we - // might want to keep track of each queue's size and if we don't see it shrink within - // a certain amount of time, we give up on that thread and leave it with a wrong-sized - // queue (which should never cause a panic) - let mut to_resize: HashSet = (0..num_threads).collect(); - // iterate until we've resized them all - loop { - // scope the mutable access to local_queues - { - let mut local_queues = self.local_queues.write(); - to_resize.retain(|thread_index| { - let queue = if let Some(queue) = local_queues.get_mut(*thread_index) { - queue - } else { - return false; - }; - if queue.len() > batch_size { - return true; - } - let new_queue = ArrayQueue::new(batch_size); - while let Some(task) = queue.pop() { - if let Err(task) = new_queue.push(task) { - // for some reason we can't push the task to the new queue - // this should never happen, but just in case we turn it into - // an unprocessed outcome - self.abandon_task(task); - } - } - // this is safe because the worker threads can't get readable access to the - // queue while this thread holds the lock - let old_queue = std::mem::replace(queue, new_queue); - assert!(old_queue.is_empty()); - false - }); - } - if to_resize.is_empty() { - return prev_batch_size; - } else { - // short sleep to give worker threads the chance to pull from their queues - std::thread::sleep(Duration::from_millis(10)); - } - } - } - - /// Returns the next task from the local queue if there are any, otherwise attempts to - /// fetch at least 1 and up to `batch_size + 1` tasks from the input channel and puts all - /// but the first one into the local queue. - #[inline] - pub(super) fn get_task( - &self, - thread_index: usize, - ) -> Option, NextTaskError>> { - let local_queue = &self.local_queues.read()[thread_index]; - // pop from the local queue if it has any tasks - if !local_queue.is_empty() { - return Some(Ok(local_queue.pop().unwrap())); - } - // otherwise pull at least 1 and up to `batch_size + 1` tasks from the input channel - let task_rx = self.task_rx.lock(); - // wait for the next task from the receiver - let first = super::task_recv_timeout(&task_rx); - // if we fail after trying to get one, don't keep trying to fill the queue - if first.as_ref().map(|result| result.is_ok()).unwrap_or(false) { - let batch_size = self.batch_size(); - // batch size 0 means batching is disabled - if batch_size > 0 { - // otherwise try to take up to `batch_size` tasks from the input channel - // and add them to the local queue, but don't block if the input channel - // is empty - for result in task_rx - .try_iter() - .take(batch_size) - .map(|task| local_queue.push(task)) - { - if let Err(task) = result { - // for some reason we can't push the task to the local queue; - // this should never happen, but just in case we turn it into an - // unprocessed outcome and stop iterating - self.abandon_task(task); - break; - } - } - } - } - first - } - } -} - -/// Sends each `Task` to its associated outcome sender (if any) or stores it in `outcomes`. -/// TODO: if `outcomes` were `DerefMut` then the argument could either be a mutable referece or -/// a Lazy that aquires the lock on first access. 
Unfortunately, rust's Lazy does not support -/// mutable access, so we'd need something like OnceCell or OnceMutex. -fn send_or_store>>( - tasks: I, - outcomes: &mut HashMap>, -) { - tasks.for_each(|task| { - let (outcome, outcome_tx) = task.into_unprocessed(); - if let Some(outcome) = if let Some(tx) = outcome_tx { - tx.try_send_msg(outcome) - } else { - Some(outcome) - } { - outcomes.insert(*outcome.task_id(), outcome); - } - }); -} - -#[cfg(not(feature = "retry"))] -mod no_retry { - use super::{LocalQueue, NextTaskError, Task}; - use crate::atomic::Atomic; - use crate::bee::{Queen, Worker}; - use crate::hive::{Husk, Shared}; - - impl, L: LocalQueue> Shared { - /// Returns the next queued `Task`. The thread blocks until a new task becomes available, and - /// since this requires holding a lock on the task `Reciever`, this also blocks any other - /// threads that call this method. Returns `None` if the task `Sender` has hung up and there - /// are no tasks queued. Also returns `None` if the cancelled flag has been set. - pub fn next_task(&self, thread_index: usize) -> Result, NextTaskError> { - loop { - self.resume_gate.wait_while(|| self.is_suspended()); - - if self.is_poisoned() { - return Err(NextTaskError::Poisoned); - } - - if let Some(result) = self.get_task(thread_index) { - break result; - } - } - .and_then(|task| match self.num_tasks.transfer(1) { - Ok(_) => Ok(task), - Err(e) => { - // poison the hive so it can't be used anymore - self.poison(); - Err(NextTaskError::InvalidCounter(e)) - } - }) - } - - /// Drains all queued tasks, converts them into `Outcome::Unprocessed` outcomes, and tries - /// to send them or (if the task does not have a sender, or if the send fails) stores them - /// in the `outcomes` map. - pub fn drain_tasks_into_unprocessed(&self) { - let task_rx = self.task_rx.lock(); - let mut outcomes = self.outcomes.lock(); - super::send_or_store(task_rx.try_iter(), &mut outcomes); - } - - /// Consumes this `Shared` and returns a `Husk` containing the `Queen`, panic count, stored - /// outcomes, and all configuration information necessary to create a new `Hive`. Any queued - /// tasks are converted into `Outcome::Unprocessed` outcomes and either sent to the task's - /// sender or (if there is no sender, or the send fails) stored in the `outcomes` map. - pub fn try_into_husk(self) -> Husk { - let task_rx = self.task_rx.into_inner(); - let mut outcomes = self.outcomes.into_inner(); - super::send_or_store(task_rx.try_iter(), &mut outcomes); - Husk::new( - self.config.into_unsync(), - self.queen.into_inner(), - self.num_panics.into_inner(), - outcomes, - ) + self.local_queues.resize(0, num_threads, batch_size, &self); + prev_batch_size } } } #[cfg(feature = "retry")] mod retry { - use super::NextTaskError; - use crate::atomic::Atomic; - use crate::bee::{Context, Queen, Worker}; - use crate::hive::delay::DelayQueue; - use crate::hive::{Husk, LocalQueues, OutcomeSender, Shared, Task}; - use std::time::{Duration, Instant}; - - impl, L: LocalQueues> Shared { - /// Initializes the retry queues worker threads in the specified range. 
- pub(super) fn init_retry_queues(&self, start_index: usize, end_index: usize) { - let mut retry_queues = self.retry_queues.write(); - assert_eq!(retry_queues.len(), start_index); - (start_index..end_index).for_each(|_| retry_queues.push(DelayQueue::default())) - } + use crate::bee::{Queen, Worker}; + use crate::hive::{GlobalQueue, LocalQueues, OutcomeSender, Shared, Task, TaskId}; + use std::time::Instant; + impl Shared + where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + { /// Returns `true` if the hive is configured to retry tasks and the `attempt` field of the /// given `ctx` is less than the maximum number of retries. - pub fn can_retry(&self, ctx: &Context) -> bool { + pub fn can_retry(&self, attempt: u32) -> bool { self.config .max_retries .get() - .map(|max_retries| ctx.attempt() < max_retries) + .map(|max_retries| attempt < max_retries) .unwrap_or(false) } - /// Adds a task to the retry queue with a delay based on `ctx.attempt()`. - pub fn queue_retry( + /// Adds a task with the given `task_id`, `input`, and `outcome_tx` to the local retry + /// queue for the specified `thread_index`. + pub fn send_retry( &self, - thread_index: usize, + task_id: TaskId, input: W::Input, - ctx: Context, outcome_tx: Option>, + attempt: u32, + thread_index: usize, ) -> Option { - // compute the delay - let delay = self - .config - .retry_factor - .get() - .map(|retry_factor| { - 2u64.checked_pow(ctx.attempt() - 1) - .and_then(|multiplier| { - retry_factor - .checked_mul(multiplier) - .or(Some(u64::MAX)) - .map(Duration::from_nanos) - }) - .unwrap() - }) - .unwrap_or_default(); - // try to queue the task - let task = Task::new(input, ctx, outcome_tx); self.num_tasks .increment_left(1) .expect("overflowed queued task counter"); - if let Some(queue) = self.retry_queues.read().get(thread_index) { - queue.push(task, delay) - } else { - Err(task) - } - // if unable to queue the task, abandon it - .map_err(|task| self.abandon_task(task)) - .ok() - } - - /// Returns the next queued `Task`. The thread blocks until a new task becomes available, - /// and since this requires holding a lock on the task `Reciever`, this also blocks any - /// other threads that call this method. Returns an error if the task `Sender` has hung up - /// and there are no tasks queued for retry. - pub fn next_task(&self, thread_index: usize) -> Result, NextTaskError> { - loop { - self.resume_gate.wait_while(|| self.is_suspended()); - - if self.is_poisoned() { - return Err(NextTaskError::Poisoned); - } - - if let Some(task) = self - .retry_queues - .read() - .get(thread_index) - .and_then(|queue| queue.try_pop()) - { - break Ok(task); - } - - if let Some(result) = self.get_task(thread_index) { - break result; - } - } - .and_then(|task| match self.num_tasks.transfer(1) { - Ok(_) => Ok(task), - Err(e) => Err(NextTaskError::InvalidCounter(e)), - }) - } - - /// Drains all queued tasks, converts them into `Outcome::Unprocessed` outcomes, and tries - /// to send them or (if the task does not have a sender, or if the send fails) stores them - /// in the `outcomes` map. 
- pub fn drain_tasks_into_unprocessed(&self) { - let mut outcomes = self.outcomes.lock(); - let task_rx = self.task_rx.lock(); - super::send_or_store(task_rx.try_iter(), &mut outcomes); - let mut retry_queue = self.retry_queues.write(); - for queue in retry_queue.iter_mut() { - super::send_or_store(queue.drain(), &mut outcomes); - } - } - - /// Consumes this `Shared` and returns a `Husk` containing the `Queen`, panic count, stored - /// outcomes, and all configuration information necessary to create a new `Hive`. Any queued - /// tasks are converted into `Outcome::Unprocessed` outcomes and either sent to the task's - /// sender or (if there is no sender, or the send fails) stored in the `outcomes` map. - pub fn try_into_husk(self) -> Husk { - let mut outcomes = self.outcomes.into_inner(); - let task_rx = self.task_rx.into_inner(); - super::send_or_store(task_rx.try_iter(), &mut outcomes); - let mut retry_queue = self.retry_queues.into_inner(); - for queue in retry_queue.iter_mut() { - super::send_or_store(queue.drain(), &mut outcomes); - } - Husk::new( - self.config.into_unsync(), - self.queen.into_inner(), - self.num_panics.into_inner(), - outcomes, - ) + let task = Task::with_attempt(task_id, input, outcome_tx, attempt); + self.local_queues.retry(task, thread_index, &self) } } } @@ -783,13 +635,15 @@ mod retry { mod tests { use crate::bee::stock::ThunkWorker; use crate::bee::DefaultQueen; - use crate::hive::LocalQueuesImpl; + use crate::hive::task::ChannelGlobalQueue; + use crate::hive::ChannelLocalQueues; type VoidThunkWorker = ThunkWorker<()>; type VoidThunkWorkerShared = super::Shared< VoidThunkWorker, DefaultQueen, - LocalQueuesImpl, + ChannelGlobalQueue, + ChannelLocalQueues, >; #[test] diff --git a/src/hive/task/delay.rs b/src/hive/task/delay.rs index 7e13186..7faa6c4 100644 --- a/src/hive/task/delay.rs +++ b/src/hive/task/delay.rs @@ -14,6 +14,11 @@ use std::time::{Duration, Instant}; pub struct DelayQueue(UnsafeCell>>); impl DelayQueue { + /// Returns the number of items currently in the queue. + pub fn len(&self) -> usize { + unsafe { self.0.get().as_ref().unwrap().len() } + } + /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available, /// or an error with `item` if there was an error pushing the item. pub fn push(&self, item: T, delay: Duration) -> Result { @@ -120,12 +125,6 @@ mod tests { use super::DelayQueue; use std::{thread, time::Duration}; - impl DelayQueue { - fn len(&self) -> usize { - unsafe { self.0.get().as_ref().unwrap().len() } - } - } - #[test] fn test_works() { let queue = DelayQueue::default(); diff --git a/src/hive/task/global.rs b/src/hive/task/global.rs new file mode 100644 index 0000000..9311ee6 --- /dev/null +++ b/src/hive/task/global.rs @@ -0,0 +1,77 @@ +pub use channel::GlobalQueueImpl as ChannelGlobalQueue; + +mod channel { + use crate::atomic::{Atomic, AtomicBool}; + use crate::bee::Worker; + use crate::hive::{GlobalPopError, GlobalQueue, Task, TaskReceiver, TaskSender}; + use crossbeam_channel::RecvTimeoutError; + use std::time::Duration; + + pub struct GlobalQueueImpl { + tx: TaskSender, + rx: TaskReceiver, + closed: AtomicBool, + } + + impl GlobalQueueImpl { + /// Returns a new `GlobalQueue` that uses the given channel sender for pushing new tasks + /// and the given channel receiver for popping tasks. 
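Judging from the trait implementation that follows, the global queue's contract is four operations: `try_push`, `try_pop`, `drain`, and `close`. A deliberately naive sketch of that contract, backed by a `Mutex<VecDeque>` instead of a channel (all type and method names here are stand-ins, not the crate's):

use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;

// A naive queue with the same surface as the channel-backed one
// (stand-in types; the real implementation never blocks on push).
struct NaiveGlobalQueue<T> {
    items: Mutex<VecDeque<T>>,
    closed: AtomicBool,
}

impl<T> NaiveGlobalQueue<T> {
    fn new() -> Self {
        Self {
            items: Mutex::new(VecDeque::new()),
            closed: AtomicBool::new(false),
        }
    }

    // A rejected task is handed back to the caller so it can be turned
    // into an `Unprocessed` outcome instead of being dropped.
    fn try_push(&self, task: T) -> Result<(), T> {
        if self.closed.load(Ordering::Acquire) {
            return Err(task);
        }
        self.items.lock().unwrap().push_back(task);
        Ok(())
    }

    fn try_pop(&self) -> Option<T> {
        self.items.lock().unwrap().pop_front()
    }

    fn drain(&self) -> Vec<T> {
        self.items.lock().unwrap().drain(..).collect()
    }

    fn close(&self) {
        self.closed.store(true, Ordering::Release);
    }
}

fn main() {
    let queue = NaiveGlobalQueue::new();
    assert!(queue.try_push(1).is_ok());
    assert!(queue.try_push(2).is_ok());
    assert_eq!(queue.try_pop(), Some(1));
    queue.close();
    assert_eq!(queue.try_push(3), Err(3)); // rejected after close
    assert_eq!(queue.drain(), vec![2]); // leftover work for shutdown
}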
+ pub fn new(tx: TaskSender, rx: TaskReceiver) -> Self { + Self { + tx, + rx, + closed: AtomicBool::default(), + } + } + + #[cfg(feature = "batching")] + pub fn try_iter(&self) -> crossbeam_channel::TryIter> { + self.rx.try_iter() + } + + pub fn try_pop_timeout( + &self, + timeout: Duration, + ) -> Option, GlobalPopError>> { + match self.rx.recv_timeout(timeout) { + Ok(task) => Some(Ok(task)), + Err(RecvTimeoutError::Disconnected) => Some(Err(GlobalPopError::Closed)), + Err(RecvTimeoutError::Timeout) if self.closed.get() && self.rx.is_empty() => { + Some(Err(GlobalPopError::Closed)) + } + Err(RecvTimeoutError::Timeout) => None, + } + } + } + + impl GlobalQueue for GlobalQueueImpl { + fn try_push(&self, task: Task) -> Result<(), Task> { + if !self.closed.get() { + self.tx.try_send(task).map_err(|err| err.into_inner()) + } else { + Err(task) + } + } + + fn try_pop(&self) -> Option, GlobalPopError>> { + // time to wait in between polling the retry queue and then the task receiver + const RECV_TIMEOUT: Duration = Duration::from_secs(1); + self.try_pop_timeout(RECV_TIMEOUT) + } + + fn drain(&self) -> Vec> { + self.rx.try_iter().collect() + } + + fn close(&self) { + self.closed.set(true); + } + } + + impl Default for GlobalQueueImpl { + fn default() -> Self { + let (tx, rx) = crossbeam_channel::unbounded(); + Self::new(tx, rx) + } + } +} diff --git a/src/hive/task/iter.rs b/src/hive/task/iter.rs deleted file mode 100644 index 90bd8b4..0000000 --- a/src/hive/task/iter.rs +++ /dev/null @@ -1,27 +0,0 @@ -use crate::bee::{Queen, Worker}; -use crate::hive::counter::CounterError; -use crate::hive::{LocalQueues, Shared, Task}; -use std::sync::Arc; - -#[derive(thiserror::Error, Debug)] -pub enum NextTaskError { - #[error("Task receiver disconnected")] - Disconnected, - #[error("The hive has been poisoned")] - Poisoned, - #[error("Task counter has invalid state")] - InvalidCounter(CounterError), -} - -pub struct TaskIterator, L: LocalQueues> { - thread_index: usize, - shared: Arc>, -} - -impl, L: LocalQueues> Iterator for TaskIterator { - type Item = Task; - - fn next(&mut self) -> Option { - todo!() - } -} diff --git a/src/hive/task/local.rs b/src/hive/task/local.rs index f7dee9e..23b5ce6 100644 --- a/src/hive/task/local.rs +++ b/src/hive/task/local.rs @@ -1,26 +1,64 @@ #[cfg(any(feature = "batching", feature = "retry"))] -pub use channel::ChannelLocalQueues as LocalQueuesImpl; +pub use channel::LocalQueuesImpl as ChannelLocalQueues; #[cfg(not(any(feature = "batching", feature = "retry")))] -pub use null::NullLocalQueues as LocalQueuesImpl; +pub use null::LocalQueuesImpl as ChannelLocalQueues; #[cfg(not(any(feature = "batching", feature = "retry")))] mod null { - use crate::bee::Worker; - use crate::hive::LocalQueues; + use crate::bee::{Queen, Worker}; + use crate::hive::{ChannelGlobalQueue, LocalQueues, Shared, Task}; use std::marker::PhantomData; - pub struct NullLocalQueues(PhantomData); + pub struct LocalQueuesImpl(PhantomData W>); - impl LocalQueues for NullLocalQueues {} + impl LocalQueues> for LocalQueuesImpl { + fn init_for_threads>( + &self, + _: usize, + _: usize, + _: &Shared, Self>, + ) { + } + + #[inline(always)] + fn push>( + &self, + task: Task, + _: usize, + shared: &Shared, Self>, + ) { + shared.push_global(task); + } + + #[inline(always)] + fn try_pop>( + &self, + _: usize, + _: &Shared, Self>, + ) -> Option> { + None + } + + fn drain(&self) -> Vec> { + Vec::new() + } + } + + impl Default for LocalQueuesImpl { + fn default() -> Self { + Self(PhantomData) + } + } } #[cfg(any(feature = 
"batching", feature = "retry"))] mod channel { - use crate::bee::Worker; - use crate::hive::{LocalQueues, Task}; + use crate::bee::{Queen, Worker}; + use crate::hive::task::ChannelGlobalQueue; + use crate::hive::{LocalQueues, Shared, Task}; use parking_lot::RwLock; - pub struct ChannelLocalQueues { + pub struct LocalQueuesImpl { /// thread-local queues of tasks used when the `batching` feature is enabled #[cfg(feature = "batching")] batch_queues: RwLock>>>, @@ -29,9 +67,158 @@ mod channel { retry_queues: RwLock>>>, } - impl LocalQueues for ChannelLocalQueues {} + #[cfg(feature = "retry")] + impl LocalQueuesImpl { + #[inline] + fn try_pop_retry(&self, thread_index: usize) -> Option> { + self.retry_queues + .read() + .get(thread_index) + .and_then(|queue| queue.try_pop()) + } + } + + #[cfg(feature = "batching")] + impl LocalQueuesImpl { + // time to wait in between polling the retry queue and then the task receiver + const POP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1); - impl Default for ChannelLocalQueues { + #[inline] + fn try_push_local(&self, task: Task, thread_index: usize) -> Result<(), Task> { + self.batch_queues.read()[thread_index].push(task) + } + + #[inline] + fn try_pop_local_or_refill>( + &self, + thread_index: usize, + shared: &Shared, Self>, + ) -> Option> { + let local_queue = &self.batch_queues.read()[thread_index]; + // pop from the local queue if it has any tasks + if !local_queue.is_empty() { + return local_queue.pop(); + } + // otherwise pull at least 1 and up to `batch_size + 1` tasks from the input channel + // wait for the next task from the receiver + let first = shared + .global_queue + .try_pop_timeout(Self::POP_TIMEOUT) + .map(Result::ok) + .flatten(); + // if we fail after trying to get one, don't keep trying to fill the queue + if first.is_some() { + let batch_size = shared.batch_size(); + // batch size 0 means batching is disabled + if batch_size > 0 { + // otherwise try to take up to `batch_size` tasks from the input channel + // and add them to the local queue, but don't block if the input channel + // is empty + for result in shared + .global_queue + .try_iter() + .take(batch_size) + .map(|task| local_queue.push(task)) + { + if let Err(task) = result { + // for some reason we can't push the task to the local queue; + // this should never happen, but just in case we turn it into an + // unprocessed outcome and stop iterating + shared.abandon_task(task); + break; + } + } + } + } + first + } + } + + impl LocalQueues> for LocalQueuesImpl { + fn init_for_threads>( + &self, + start_index: usize, + end_index: usize, + #[allow(unused_variables)] shared: &Shared, Self>, + ) { + #[cfg(feature = "batching")] + self.init_batch_queues_for_threads(start_index, end_index, shared); + #[cfg(feature = "retry")] + self.init_retry_queues_for_threads(start_index, end_index); + } + + #[cfg(feature = "batching")] + fn resize>( + &self, + start_index: usize, + end_index: usize, + new_size: usize, + shared: &Shared, Self>, + ) { + self.resize_batch_queues(start_index, end_index, new_size, shared); + } + + /// Creates a task from `input` and pushes it to the local queue if there is space, + /// otherwise attempts to add it to the global queue. Returns the task ID if the push + /// succeeds, otherwise returns an error with the input. 
+ fn push>( + &self, + task: Task, + #[allow(unused_variables)] thread_index: usize, + shared: &Shared, Self>, + ) { + #[cfg(feature = "batching")] + let task = match self.try_push_local(task, thread_index) { + Ok(_) => return, + Err(task) => task, + }; + shared.push_global(task); + } + + /// Returns the next task from the local queue if there are any, otherwise attempts to + /// fetch at least 1 and up to `batch_size + 1` tasks from the input channel and puts all + /// but the first one into the local queue. + fn try_pop>( + &self, + thread_index: usize, + #[allow(unused_variables)] shared: &Shared, Self>, + ) -> Option> { + #[cfg(feature = "retry")] + if let Some(task) = self.try_pop_retry(thread_index) { + return Some(task); + } + #[cfg(feature = "batching")] + if let Some(task) = self.try_pop_local_or_refill(thread_index, shared) { + return Some(task); + } + None + } + + fn drain(&self) -> Vec> { + let mut tasks = Vec::new(); + #[cfg(feature = "batching")] + { + self.drain_batch_queues_into(&mut tasks); + } + #[cfg(feature = "retry")] + { + self.drain_retry_queues_into(&mut tasks); + } + tasks + } + + #[cfg(feature = "retry")] + fn retry>( + &self, + task: Task, + thread_index: usize, + shared: &Shared, Self>, + ) -> Option { + self.try_push_retry(task, thread_index, shared) + } + } + + impl Default for LocalQueuesImpl { fn default() -> Self { Self { #[cfg(feature = "batching")] @@ -41,4 +228,160 @@ mod channel { } } } + + #[cfg(feature = "batching")] + mod batching { + use super::LocalQueuesImpl; + use crate::bee::{Queen, Worker}; + use crate::hive::{ChannelGlobalQueue, Shared, Task}; + use crossbeam_queue::ArrayQueue; + use std::collections::HashSet; + use std::time::Duration; + + impl LocalQueuesImpl { + pub(super) fn init_batch_queues_for_threads>( + &self, + start_index: usize, + end_index: usize, + shared: &Shared, Self>, + ) { + let mut batch_queues = self.batch_queues.write(); + assert_eq!(batch_queues.len(), start_index); + let queue_size = shared.batch_size().max(1); + (start_index..end_index) + .for_each(|_| batch_queues.push(ArrayQueue::new(queue_size))); + } + + pub(super) fn resize_batch_queues>( + &self, + start_index: usize, + end_index: usize, + batch_size: usize, + shared: &Shared, Self>, + ) { + // keep track of which queues need to be resized + // TODO: this method could cause a hang if one of the worker threads is stuck - we + // might want to keep track of each queue's size and if we don't see it shrink + // within a certain amount of time, we give up on that thread and leave it with a + // wrong-sized queue (which should never cause a panic) + let mut to_resize: HashSet = (start_index..end_index).collect(); + // iterate until we've resized them all + loop { + // scope the mutable access to local_queues + { + let mut batch_queues = self.batch_queues.write(); + to_resize.retain(|thread_index| { + let queue = if let Some(queue) = batch_queues.get_mut(*thread_index) { + queue + } else { + return false; + }; + if queue.len() > batch_size { + return true; + } + let new_queue = ArrayQueue::new(batch_size); + while let Some(task) = queue.pop() { + if let Err(task) = new_queue.push(task) { + // for some reason we can't push the task to the new queue + // this should never happen, but just in case we turn it into + // an unprocessed outcome + shared.abandon_task(task); + } + } + // this is safe because the worker threads can't get readable access to the + // queue while this thread holds the lock + let old_queue = std::mem::replace(queue, new_queue); + 
assert!(old_queue.is_empty());
+                            false
+                        });
+                    }
+                    if to_resize.is_empty() {
+                        return;
+                    }
+                    // short sleep to give worker threads the chance to pull from their queues
+                    std::thread::sleep(Duration::from_millis(10));
+                }
+            }
+
+            pub(super) fn drain_batch_queues_into(&self, tasks: &mut Vec>) {
+                let _ = self
+                    .batch_queues
+                    .write()
+                    .iter_mut()
+                    .fold(tasks, |tasks, queue| {
+                        tasks.reserve(queue.len());
+                        while let Some(task) = queue.pop() {
+                            tasks.push(task);
+                        }
+                        tasks
+                    });
+            }
+        }
+    }
+
+    #[cfg(feature = "retry")]
+    mod retry {
+        use super::LocalQueuesImpl;
+        use crate::bee::{Queen, Worker};
+        use crate::hive::task::delay::DelayQueue;
+        use crate::hive::{ChannelGlobalQueue, Shared, Task};
+        use std::time::{Duration, Instant};
+
+        impl LocalQueuesImpl {
+            /// Initializes the retry queues for the worker threads in the specified range.
+            pub(super) fn init_retry_queues_for_threads(
+                &self,
+                start_index: usize,
+                end_index: usize,
+            ) {
+                let mut retry_queues = self.retry_queues.write();
+                assert_eq!(retry_queues.len(), start_index);
+                (start_index..end_index).for_each(|_| retry_queues.push(DelayQueue::default()))
+            }
+
+            /// Adds a task to the retry queue with a delay based on `attempt`.
+            pub(super) fn try_push_retry>(
+                &self,
+                task: Task,
+                thread_index: usize,
+                shared: &Shared, Self>,
+            ) -> Option {
+                // compute the delay
+                let delay = shared
+                    .config
+                    .retry_factor
+                    .get()
+                    .map(|retry_factor| {
+                        2u64.checked_pow(task.attempt - 1)
+                            .and_then(|multiplier| {
+                                retry_factor
+                                    .checked_mul(multiplier)
+                                    .or(Some(u64::MAX))
+                                    .map(Duration::from_nanos)
+                            })
+                            .unwrap()
+                    })
+                    .unwrap_or_default();
+                if let Some(queue) = self.retry_queues.read().get(thread_index) {
+                    queue.push(task, delay)
+                } else {
+                    Err(task)
+                }
+                // if unable to queue the task, abandon it
+                .map_err(|task| shared.abandon_task(task))
+                .ok()
+            }
+
+            pub(super) fn drain_retry_queues_into(&self, tasks: &mut Vec>) {
+                let _ = self
+                    .retry_queues
+                    .write()
+                    .iter_mut()
+                    .fold(tasks, |tasks, queue| {
+                        tasks.reserve(queue.len());
+                        tasks.extend(queue.drain());
+                        tasks
+                    });
+            }
+        }
+    }
}
diff --git a/src/hive/task/mod.rs b/src/hive/task/mod.rs
index e636987..0728e90 100644
--- a/src/hive/task/mod.rs
+++ b/src/hive/task/mod.rs
@@ -1,41 +1,75 @@
 #[cfg(feature = "retry")]
 mod delay;
-mod iter;
+mod global;
 mod local;
 
-pub use local::LocalQueuesImpl;
+pub use global::ChannelGlobalQueue;
+pub use local::ChannelLocalQueues;
 
 use super::{Outcome, OutcomeSender, Task};
-use crate::bee::{Context, TaskId, Worker};
+use crate::bee::{TaskId, Worker};
 
+impl Task {
+    /// Returns the ID of this task.
+    pub fn id(&self) -> TaskId {
+        self.id
+    }
+
+    /// Consumes this `Task` and returns an `Outcome::Unprocessed` outcome with the input and ID,
+    /// and the outcome sender.
+    pub fn into_unprocessed(self) -> (Outcome, Option>) {
+        let outcome = Outcome::Unprocessed {
+            input: self.input,
+            task_id: self.id,
+        };
+        (outcome, self.outcome_tx)
+    }
+}
+
+#[cfg(not(feature = "retry"))]
 impl Task {
     /// Creates a new `Task`.
-    pub fn new(input: W::Input, ctx: Context, outcome_tx: Option>) -> Self {
+    pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self {
         Task {
+            id,
             input,
-            ctx,
             outcome_tx,
         }
     }
 
-    /// Returns the ID of this task.
-    pub fn id(&self) -> TaskId {
-        self.ctx.task_id()
+    pub fn into_parts(self) -> (TaskId, W::Input, Option>) {
+        (self.id, self.input, self.outcome_tx)
     }
+}
 
-    /// Consumes this `Task` and returns a tuple `(input, context, outcome_tx)`.
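The delay computed in `try_push_retry` is exponential backoff: `retry_factor * 2^(attempt - 1)` nanoseconds, with the multiplication saturating at `u64::MAX`. A standalone restatement (note that this sketch also saturates when `2^(attempt - 1)` itself overflows, a case the code above would `unwrap` and panic on):

use std::time::Duration;

// Exponential backoff with base `retry_factor` nanoseconds, doubling per
// failed attempt and saturating instead of overflowing.
fn retry_delay(retry_factor_nanos: u64, attempt: u32) -> Duration {
    let multiplier = 2u64
        .checked_pow(attempt.saturating_sub(1))
        .unwrap_or(u64::MAX);
    Duration::from_nanos(retry_factor_nanos.saturating_mul(multiplier))
}

fn main() {
    let base = 1_000_000; // a 1 ms retry factor
    assert_eq!(retry_delay(base, 1), Duration::from_millis(1));
    assert_eq!(retry_delay(base, 2), Duration::from_millis(2));
    assert_eq!(retry_delay(base, 4), Duration::from_millis(8));
}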
- pub fn into_parts(self) -> (W::Input, Context, Option>) { - (self.input, self.ctx, self.outcome_tx) +#[cfg(feature = "retry")] +impl Task { + /// Creates a new `Task`. + pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { + Task { + id, + input, + outcome_tx, + attempt: 0, + } } - /// Consumes this `Task` and returns a `Outcome::Unprocessed` outcome with the input and ID, - /// and the outcome sender. - pub fn into_unprocessed(self) -> (Outcome, Option>) { - let (input, ctx, outcome_tx) = self.into_parts(); - let outcome = Outcome::Unprocessed { + /// Creates a new `Task`. + pub fn with_attempt( + id: TaskId, + input: W::Input, + outcome_tx: Option>, + attempt: u32, + ) -> Self { + Task { + id, input, - task_id: ctx.task_id(), - }; - (outcome, outcome_tx) + outcome_tx, + attempt, + } + } + + pub fn into_parts(self) -> (TaskId, W::Input, u32, Option>) { + (self.id, self.input, self.attempt, self.outcome_tx) } } diff --git a/src/util.rs b/src/util.rs index aab909d..f326057 100644 --- a/src/util.rs +++ b/src/util.rs @@ -154,7 +154,7 @@ mod retry { O: Send + Sync + 'static, E: Send + Sync + Debug + 'static, Inputs: IntoIterator, - F: FnMut(I, &Context) -> Result> + Send + Sync + Clone + 'static, + F: FnMut(I, &Context) -> Result> + Send + Sync + Clone + 'static, { Builder::default() .num_threads(num_threads) From b427c72b966f37b536826fc93ef2a5c2a1db95d6 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 7 Feb 2025 15:42:20 -0800 Subject: [PATCH 05/67] implement non-locking outcome store --- CHANGELOG.md | 5 ++-- src/hive/mod.rs | 6 ++-- src/hive/outcome/mod.rs | 2 ++ src/hive/outcome/queue.rs | 58 +++++++++++++++++++++++++++++++++++++++ src/hive/shared.rs | 14 ++++------ 5 files changed, 71 insertions(+), 14 deletions(-) create mode 100644 src/hive/outcome/queue.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 66bd282..c39b72e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,12 +3,13 @@ ## 0.3.0 * **Breaking** - * `beekeeper::bee::Context` has been changed from a struct to a trait, and `beekeeper::bee::Worker` now has a generic parameter for the `Context` type. + * `beekeeper::bee::Context` now takes a generic parameter that must be input type of the `Worker`. * Features * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. * Added the `Context::submit` method, which enables tasks to submit new tasks to the `Hive`. * Other - * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread-contention + * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread-contention. + * Switched to storing `Outcome`s in the hive using a data structure that does not require locking when inserting, which should reduce thread contention when using `*_store` operations. ## 0.2.1 diff --git a/src/hive/mod.rs b/src/hive/mod.rs index a912049..d476d70 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -380,7 +380,7 @@ pub use self::config::{ set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, }; pub use self::husk::Husk; -pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; +pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeQueue, OutcomeStore}; /// Sender type for channel used to send task outcomes. 
pub type OutcomeSender = crate::channel::Sender>; @@ -408,7 +408,6 @@ use self::task::{ChannelGlobalQueue, ChannelLocalQueues}; use crate::atomic::{AtomicAny, AtomicBool, AtomicOption, AtomicUsize}; use crate::bee::{Queen, TaskId, Worker}; use parking_lot::Mutex; -use std::collections::HashMap; use std::io::Error as SpawnError; use std::sync::Arc; use std::thread::JoinHandle; @@ -499,8 +498,7 @@ struct Shared, G: GlobalQueue, L: LocalQueues>>, + outcomes: OutcomeQueue, } #[derive(thiserror::Error, Debug)] diff --git a/src/hive/outcome/mod.rs b/src/hive/outcome/mod.rs index bd16c81..e908aef 100644 --- a/src/hive/outcome/mod.rs +++ b/src/hive/outcome/mod.rs @@ -2,11 +2,13 @@ mod batch; mod iter; #[allow(clippy::module_inception)] mod outcome; +mod queue; mod store; pub use batch::OutcomeBatch; pub use iter::OutcomeIteratorExt; pub use outcome::Outcome; +pub use queue::OutcomeQueue; pub use store::OutcomeStore; pub(super) use store::sealed::{DerefOutcomes, OwnedOutcomes}; diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs new file mode 100644 index 0000000..b109983 --- /dev/null +++ b/src/hive/outcome/queue.rs @@ -0,0 +1,58 @@ +use super::Outcome; +use crate::bee::Worker; +use crate::hive::TaskId; +use crossbeam_queue::SegQueue; +use parking_lot::Mutex; +use std::collections::HashMap; +use std::ops::DerefMut; + +pub struct OutcomeQueue { + queue: SegQueue>, + outcomes: Mutex>>, +} + +impl OutcomeQueue { + /// Adds an outcome to the queue. + pub fn push(&self, outcome: Outcome) { + self.queue.push(outcome); + } + + /// Flushes the queue into the map of outcomes and returns a mutable reference to the map. + pub fn get_mut(&self) -> impl DerefMut>> + '_ { + let mut outcomes = self.outcomes.lock(); + // add any queued outcomes to the map + while let Some(outcome) = self.queue.pop() { + outcomes.insert(*outcome.task_id(), outcome); + } + outcomes + } + + /// Flushes the queue into the map of outcomes, then takes all outcomes from the map and + /// returns them. + pub fn drain(&self) -> HashMap> { + let mut outcomes: HashMap> = self.outcomes.lock().drain().collect(); + // add any queued outcomes to the map + while let Some(outcome) = self.queue.pop() { + outcomes.insert(*outcome.task_id(), outcome); + } + outcomes + } + + pub fn into_inner(self) -> HashMap> { + let mut outcomes = self.outcomes.into_inner(); + // add any queued outcomes to the map + while let Some(outcome) = self.queue.pop() { + outcomes.insert(*outcome.task_id(), outcome); + } + outcomes + } +} + +impl Default for OutcomeQueue { + fn default() -> Self { + Self { + queue: Default::default(), + outcomes: Default::default(), + } + } +} diff --git a/src/hive/shared.rs b/src/hive/shared.rs index 91cdc72..5a7d8be 100644 --- a/src/hive/shared.rs +++ b/src/hive/shared.rs @@ -9,7 +9,7 @@ use parking_lot::Mutex; use std::collections::HashMap; use std::ops::DerefMut; use std::thread::{Builder, JoinHandle}; -use std::{fmt, iter, mem}; +use std::{fmt, iter}; impl Shared where @@ -335,7 +335,7 @@ where Some(outcome) } { outcomes - .get_or_insert_with(|| self.outcomes.lock()) + .get_or_insert_with(|| self.outcomes.get_mut()) .insert(task_id, outcome); } task_id @@ -451,24 +451,22 @@ where /// Returns a mutable reference to the retained task outcomes. pub fn outcomes(&self) -> impl DerefMut>> + '_ { - self.outcomes.lock() + self.outcomes.get_mut() } /// Adds a new outcome to the retained task outcomes. 
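The idea behind `OutcomeQueue` is to make inserts lock-free and defer all locking to readers, which flush the queue into the map before using it. The same shape in miniature, with `std::sync::mpsc` standing in for `crossbeam_queue::SegQueue` (names here are illustrative):

use std::collections::HashMap;
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Mutex, MutexGuard};

// Writers only touch the channel; the map's lock is taken only by
// readers, which fold any pending entries into the map first.
struct FlushOnRead {
    tx: Sender<(usize, String)>,
    rx: Mutex<Receiver<(usize, String)>>,
    map: Mutex<HashMap<usize, String>>,
}

impl FlushOnRead {
    // no lock on the insert path
    fn push(&self, id: usize, value: String) {
        let _ = self.tx.send((id, value));
    }

    // flush queued entries into the map, then hand out the map
    fn get_mut(&self) -> MutexGuard<'_, HashMap<usize, String>> {
        let mut map = self.map.lock().unwrap();
        for (id, value) in self.rx.lock().unwrap().try_iter() {
            map.insert(id, value);
        }
        map
    }
}

fn main() {
    let (tx, rx) = mpsc::channel();
    let store = FlushOnRead {
        tx,
        rx: Mutex::new(rx),
        map: Mutex::new(HashMap::new()),
    };
    store.push(1, "done".to_string());
    assert_eq!(store.get_mut().get(&1).map(String::as_str), Some("done"));
}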
pub fn add_outcome(&self, outcome: Outcome) { - let mut lock = self.outcomes.lock(); - lock.insert(*outcome.task_id(), outcome); + self.outcomes.push(outcome); } /// Removes and returns all retained task outcomes. pub fn take_outcomes(&self) -> HashMap> { - let mut lock = self.outcomes.lock(); - mem::take(&mut *lock) + self.outcomes.drain() } /// Removes and returns all retained `Unprocessed` outcomes. pub fn take_unprocessed(&self) -> Vec> { - let mut outcomes = self.outcomes.lock(); + let mut outcomes = self.outcomes.get_mut(); let unprocessed_task_ids: Vec<_> = outcomes .keys() .cloned() From 57a54ada0f2b6c75dd55246c8a0ca50acdbaae5b Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 7 Feb 2025 16:01:48 -0800 Subject: [PATCH 06/67] fix lints --- Cargo.toml | 4 ++-- src/bee/context.rs | 14 +++++--------- src/hive/channel.rs | 15 +++++++++------ src/hive/mod.rs | 9 +++++---- src/hive/shared.rs | 24 +++++++++++++++--------- src/hive/task/local.rs | 3 +-- 6 files changed, 37 insertions(+), 32 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 25c7f68..45bcdb8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "beekeeper" description = "A full-featured worker pool library for parallelizing tasks" version = "0.3.0" edition = "2021" -rust-version = "1.80" +rust-version = "1.83" authors = ["John Didion "] repository = "https://github.com/jdidion/beekeeper" license = "MIT OR Apache-2.0" @@ -37,7 +37,7 @@ name = "perf" harness = false [features] -default = ["batching", "retry"] +default = ["affinity", "batching", "retry"] affinity = ["dep:core_affinity"] batching = ["dep:crossbeam-queue"] retry = [] diff --git a/src/bee/context.rs b/src/bee/context.rs index 0eec520..0d5b10f 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -16,13 +16,13 @@ pub trait TaskContext: Debug { #[derive(Debug)] pub struct Context<'a, I> { task_id: TaskId, - task_ctx: Option>>, + task_ctx: Option<&'a dyn TaskContext>, subtask_ids: Option>, #[cfg(feature = "retry")] attempt: u32, } -impl<'a, I> Context<'a, I> { +impl Context<'_, I> { /// The task_id of this task within the `Hive`. pub fn task_id(&self) -> TaskId { self.task_id @@ -65,7 +65,7 @@ impl<'a, I> Context<'a, I> { } #[cfg(not(feature = "retry"))] -impl<'a, I> Context<'a, I> { +impl Context<'_, I> { /// Returns a new empty context. This is primarily useful for testing. pub fn empty() -> Self { Self { @@ -76,7 +76,7 @@ impl<'a, I> Context<'a, I> { } /// Creates a new `Context` with the given task_id and shared cancellation status. - pub fn new(task_id: TaskId, task_ctx: Option>>) -> Self { + pub fn new(task_id: TaskId, task_ctx: Option<&'a dyn TaskContext>) -> Self { Self { task_id, task_ctx, @@ -105,11 +105,7 @@ impl<'a, I> Context<'a, I> { } /// Creates a new `Context` with the given task_id and shared cancellation status. - pub fn new( - task_id: TaskId, - attempt: u32, - task_ctx: Option>>, - ) -> Self { + pub fn new(task_id: TaskId, attempt: u32, task_ctx: Option<&'a dyn TaskContext>) -> Self { Self { task_id, attempt, diff --git a/src/hive/channel.rs b/src/hive/channel.rs index 5bfe6a3..ab156cb 100644 --- a/src/hive/channel.rs +++ b/src/hive/channel.rs @@ -58,7 +58,7 @@ impl> Hive { #[inline] fn shared(&self) -> &Arc, ChannelLocalQueues>> { - &self.0.as_ref().unwrap() + self.0.as_ref().unwrap() } /// Attempts to increase the number of worker threads by `num_threads`. 
Returns the number of @@ -733,12 +733,15 @@ mod no_affinity { mod affinity { use crate::bee::{Queen, Worker}; use crate::hive::cores::Cores; - use crate::hive::{Hive, Poisoned, Shared}; + use crate::hive::{GlobalQueue, Hive, LocalQueues, Poisoned, Shared}; impl> Hive { /// Tries to pin the worker thread to a specific CPU core. #[inline] - pub(super) fn init_thread(thread_index: usize, shared: &Shared) { + pub(super) fn init_thread, L: LocalQueues>( + thread_index: usize, + shared: &Shared, + ) { if let Some(core) = shared.get_core_affinity(thread_index) { core.try_pin_current(); } @@ -803,7 +806,7 @@ where outcome_tx: Option<&'a OutcomeSender>, } -impl<'a, W, Q, G, L> TaskContext for HiveTaskContext<'a, W, Q, G, L> +impl TaskContext for HiveTaskContext<'_, W, Q, G, L> where W: Worker, Q: Queen, @@ -820,7 +823,7 @@ where } } -impl<'a, W, Q, G, L> fmt::Debug for HiveTaskContext<'a, W, Q, G, L> +impl fmt::Debug for HiveTaskContext<'_, W, Q, G, L> where W: Worker, Q: Queen, @@ -881,7 +884,7 @@ mod retry { shared, outcome_tx: outcome_tx.as_ref(), }; - let ctx = Context::new(task_id, attempt, Some(Box::new(&task_ctx))); + let ctx = Context::new(task_id, attempt, Some(&task_ctx)); // execute the task until it succeeds or we reach maximum retries - this should // be the only place where a panic can occur let result = worker.apply(input, &ctx); diff --git a/src/hive/mod.rs b/src/hive/mod.rs index d476d70..1b7964e 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -423,6 +423,7 @@ type U64 = AtomicOption; /// /// See the [module documentation](crate::hive) for details. pub struct Hive>( + #[allow(clippy::type_complexity)] Option, ChannelLocalQueues>>>, ); @@ -1842,7 +1843,7 @@ mod affinity_tests { .core_affinity(0..2) .build_with_default::>(); - channel.map_store((0..10).map(move |i| { + hive.map_store((0..10).map(move |i| { Thunk::of(move || { if let Some(affininty) = core_affinity::get_core_ids() { eprintln!("task {} on thread with affinity {:?}", i, affininty); @@ -1859,7 +1860,7 @@ mod affinity_tests { .with_default_core_affinity() .build_with_default::>(); - channel.map_store((0..num_cpus::get()).map(move |i| { + hive.map_store((0..num_cpus::get()).map(move |i| { Thunk::of(move || { if let Some(affininty) = core_affinity::get_core_ids() { eprintln!("task {} on thread with affinity {:?}", i, affininty); @@ -1897,7 +1898,7 @@ mod batching_tests { thread::sleep(Duration::from_millis(100)); thread::current().id() }), - &tx, + tx, ); thread::sleep(Duration::from_millis(100)); task_id @@ -1911,7 +1912,7 @@ mod batching_tests { thread::current().id() }) }), - &tx, + tx, ); init_task_ids.into_iter().chain(rest_task_ids).collect() } diff --git a/src/hive/shared.rs b/src/hive/shared.rs index 5a7d8be..da8007e 100644 --- a/src/hive/shared.rs +++ b/src/hive/shared.rs @@ -1,4 +1,3 @@ -use super::task::ChannelGlobalQueue; use super::{ Config, GlobalQueue, Husk, LocalQueues, Outcome, OutcomeSender, Shared, SpawnError, Task, }; @@ -198,7 +197,7 @@ where ) -> TaskId { let task = self.prepare_task(input, outcome_tx); let task_id = task.id(); - self.local_queues.push(task, thread_index, &self); + self.local_queues.push(task, thread_index, self); task_id } @@ -272,7 +271,7 @@ where return None; } // try to get a task from the local queues - if let Some(task) = self.local_queues.try_pop(thread_index, &self) { + if let Some(task) = self.local_queues.try_pop(thread_index, self) { break Ok(task); } // fall back to requesting a task from the global queue @@ -501,11 +500,12 @@ where } } -impl fmt::Debug for Shared, L> 
+impl fmt::Debug for Shared where W: Worker, Q: Queen, - L: LocalQueues>, + G: GlobalQueue, + L: LocalQueues, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let (queued, active) = self.num_tasks(); @@ -522,9 +522,15 @@ where mod affinity { use crate::bee::{Queen, Worker}; use crate::hive::cores::{Core, Cores}; - use crate::hive::Shared; + use crate::hive::{GlobalQueue, LocalQueues, Shared}; - impl> Shared { + impl Shared + where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + { /// Adds cores to which worker threads may be pinned. pub fn add_core_affinity(&self, new_cores: Cores) { let _ = self.config.affinity.try_update_with(|mut affinity| { @@ -581,7 +587,7 @@ mod batching { if num_threads == 0 { return prev_batch_size; } - self.local_queues.resize(0, num_threads, batch_size, &self); + self.local_queues.resize(0, num_threads, batch_size, self); prev_batch_size } } @@ -624,7 +630,7 @@ mod retry { .increment_left(1) .expect("overflowed queued task counter"); let task = Task::with_attempt(task_id, input, outcome_tx, attempt); - self.local_queues.retry(task, thread_index, &self) + self.local_queues.retry(task, thread_index, self) } } } diff --git a/src/hive/task/local.rs b/src/hive/task/local.rs index 23b5ce6..1a8e1e5 100644 --- a/src/hive/task/local.rs +++ b/src/hive/task/local.rs @@ -104,8 +104,7 @@ mod channel { let first = shared .global_queue .try_pop_timeout(Self::POP_TIMEOUT) - .map(Result::ok) - .flatten(); + .and_then(Result::ok); // if we fail after trying to get one, don't keep trying to fill the queue if first.is_some() { let batch_size = shared.batch_size(); From d4c8977dccdb7fe53c4c30700646600e72caa237 Mon Sep 17 00:00:00 2001 From: jdidion Date: Tue, 11 Feb 2025 10:14:17 -0800 Subject: [PATCH 07/67] WIP --- CHANGELOG.md | 6 + benches/perf.rs | 28 +- src/bee/context.rs | 17 +- src/bee/mod.rs | 10 +- src/bee/queen.rs | 45 +- src/hive/builder.rs | 20 +- src/hive/channel.rs | 1002 ------------ src/hive/husk.rs | 33 +- src/hive/mod.rs | 3044 +++++++++++++++++++------------------ src/hive/outcome/store.rs | 5 +- src/hive/shared.rs | 59 +- src/hive/task/delay.rs | 164 -- src/hive/task/global.rs | 77 - src/hive/task/local.rs | 386 ----- src/hive/task/mod.rs | 75 - src/util.rs | 12 +- 16 files changed, 1695 insertions(+), 3288 deletions(-) delete mode 100644 src/hive/channel.rs delete mode 100644 src/hive/task/delay.rs delete mode 100644 src/hive/task/global.rs delete mode 100644 src/hive/task/local.rs delete mode 100644 src/hive/task/mod.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index c39b72e..4fc475b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,14 +2,20 @@ ## 0.3.0 +The general theme of this release is performance improvement by eliminating thread contention due to unnecessary locking of shared state. This required making some breaking changes to the API. + * **Breaking** + * `beekeeper::bee::Queen::create` now takes `&self` rather than `&mut self`. There is a new type, `beekeeper::bee::QueenMut`, with a `create(&mut self)` method, and needs to be wrapped in a `beekeeper::bee::QueenCell` to implement the `Queen` trait. This enables the `Hive` to create new workers without locking in the case of a `Queen` that does not need mutable state. * `beekeeper::bee::Context` now takes a generic parameter that must be input type of the `Worker`. * Features * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. 
* Added the `Context::submit` method, which enables tasks to submit new tasks to the `Hive`. * Other + * `beekeeper::hive::Hive` now has additional generic parameters for the global and local queue types. These default to `beekeeper::hive::ChannelGlobalQueue` and `beekeeper::hive::DefaultLocalQueues`, which provide the same behavior as before. + * `beekeeper::hive::ChannelHive` is the existing `Hive` implementation. * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread-contention. * Switched to storing `Outcome`s in the hive using a data structure that does not require locking when inserting, which should reduce thread contention when using `*_store` operations. + * Switched to using `crossbeam_channel` for the `Hive`'s task input channel. ## 0.2.1 diff --git a/benches/perf.rs b/benches/perf.rs index 35715ac..6314462 100644 --- a/benches/perf.rs +++ b/benches/perf.rs @@ -9,20 +9,20 @@ static ALLOC: AllocProfiler = AllocProfiler::system(); const THREADS: &[usize] = &[1, 4, 8, 16]; const TASKS: &[usize] = &[1, 100, 10_000, 1_000_000]; -#[bench(args = iproduct!(THREADS, TASKS))] -fn bench_apply_short_task(bencher: Bencher, (num_threads, num_tasks): (&usize, &usize)) { - let hive = Builder::new() - .num_threads(*num_threads) - .build_with_default::>(); - bencher.bench_local(|| { - let (tx, rx) = outcome_channel(); - for i in 0..*num_tasks { - hive.apply_send(i, &tx); - } - hive.join(); - rx.into_iter().take(*num_tasks).for_each(black_box_drop); - }) -} +// #[bench(args = iproduct!(THREADS, TASKS))] +// fn bench_apply_short_task(bencher: Bencher, (num_threads, num_tasks): (&usize, &usize)) { +// let hive = Builder::new() +// .num_threads(*num_threads) +// .build_with_default::>(); +// bencher.bench_local(|| { +// let (tx, rx) = outcome_channel(); +// for i in 0..*num_tasks { +// hive.apply_send(i, &tx); +// } +// hive.join(); +// rx.into_iter().take(*num_tasks).for_each(black_box_drop); +// }) +// } fn main() { divan::main(); diff --git a/src/bee/context.rs b/src/bee/context.rs index 0d5b10f..09abb48 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -7,7 +7,7 @@ pub type TaskId = usize; /// task execution. pub trait TaskContext: Debug { /// Returns `true` if tasks in progress should be cancelled. - fn cancel_tasks(&self) -> bool; + fn should_cancel_tasks(&self) -> bool; /// Submits a new task to the `Hive` that is executing the current task. fn submit_task(&self, input: I) -> TaskId; @@ -23,19 +23,19 @@ pub struct Context<'a, I> { } impl Context<'_, I> { - /// The task_id of this task within the `Hive`. + /// The unique ID of this task within the `Hive`. pub fn task_id(&self) -> TaskId { self.task_id } - /// Returns `true` if the task has been cancelled. + /// Returns `true` if the current task should be cancelled. /// /// A long-running `Worker` should check this periodically and, if it returns `true`, exit /// early with an `ApplyError::Cancelled` result. pub fn is_cancelled(&self) -> bool { self.task_ctx .as_ref() - .map(|worker| worker.cancel_tasks()) + .map(|worker| worker.should_cancel_tasks()) .unwrap_or(false) } @@ -43,7 +43,8 @@ impl Context<'_, I> { /// /// If a thread-local queue is available and has capacity, the task will be added to it, /// otherwise it is added to the global queue. The ID of the submitted task is stored in this - /// `Context` and ultimately returned in the `Outcome` of the submitting task. + /// `Context` and ultimately returned in the `subtask_ids` of the `Outcome` of the submitting + /// task. 
///
+    /// The task will be submitted with the same outcome sender as the current task, or stored in
     /// the `Hive` if there is no sender.
@@ -59,13 +60,15 @@ impl Context<'_, I> {
         }
     }
 
+    /// Consumes this `Context` and returns the IDs of the subtasks spawned during the execution
+    /// of the task, if any.
     pub(crate) fn into_subtask_ids(self) -> Option> {
         self.subtask_ids
     }
 }
 
 #[cfg(not(feature = "retry"))]
-impl Context<'_, I> {
+impl<'a, I> Context<'a, I> {
     /// Returns a new empty context. This is primarily useful for testing.
     pub fn empty() -> Self {
         Self {
@@ -84,7 +87,7 @@ impl Context<'_, I> {
         }
     }
 
-    /// The number of previous attempts to execute the current task.
+    /// The number of previous failed attempts to execute the current task.
     ///
     /// Always returns `0`.
     pub fn attempt(&self) -> u32 {
diff --git a/src/bee/mod.rs b/src/bee/mod.rs
index 56662d7..75998b1 100644
--- a/src/bee/mod.rs
+++ b/src/bee/mod.rs
@@ -115,16 +115,14 @@ mod queen;
 pub mod stock;
 mod worker;
 
-pub use context::{Context, TaskId};
+pub use context::{Context, TaskContext, TaskId};
 pub use error::{ApplyError, ApplyRefError};
-pub use queen::{CloneQueen, DefaultQueen, Queen};
+pub use queen::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut};
 pub use worker::{RefWorker, RefWorkerResult, Worker, WorkerError, WorkerResult};
 
-pub(crate) use context::TaskContext;
-
 pub mod prelude {
     pub use super::{
-        ApplyError, ApplyRefError, Context, Queen, RefWorker, RefWorkerResult, Worker, WorkerError,
-        WorkerResult,
+        ApplyError, ApplyRefError, Context, Queen, QueenMut, RefWorker, RefWorkerResult, Worker,
+        WorkerError, WorkerResult,
     };
 }
diff --git a/src/bee/queen.rs b/src/bee/queen.rs
index aa4ab97..64ba7b1 100644
--- a/src/bee/queen.rs
+++ b/src/bee/queen.rs
@@ -1,16 +1,53 @@
 //! The Queen bee trait.
 
 use super::Worker;
+use parking_lot::RwLock;
 use std::marker::PhantomData;
 
-/// A trait for stateful factories that create `Worker`s.
+/// A trait for factories that create `Worker`s.
 pub trait Queen: Send + Sync + 'static {
     /// The kind of `Worker` created by this factory.
     type Kind: Worker;
 
-    /// Returns a new instance of `Self::Kind`.
+    /// Creates and returns a new instance of `Self::Kind`, *immutably*.
+    fn create(&self) -> Self::Kind;
+}
+
+/// A trait for mutable factories that create `Worker`s.
+pub trait QueenMut: Send + Sync + 'static {
+    /// The kind of `Worker` created by this factory.
+    type Kind: Worker;
+
+    /// Creates and returns a new instance of `Self::Kind`, *mutably*.
     fn create(&mut self) -> Self::Kind;
 }
 
+/// A wrapper for a `QueenMut` that implements `Queen` using an `RwLock` internally.
+pub struct QueenCell(RwLock);
+
+impl QueenCell {
+    pub fn new(mut_queen: Q) -> Self {
+        Self(RwLock::new(mut_queen))
+    }
+
+    pub fn into_inner(self) -> Q {
+        self.0.into_inner()
+    }
+}
+
+impl Queen for QueenCell {
+    type Kind = Q::Kind;
+
+    fn create(&self) -> Self::Kind {
+        self.0.write().create()
+    }
+}
+
+impl Default for QueenCell {
+    fn default() -> Self {
+        Self::new(Q::default())
+    }
+}
+
 /// A `Queen` that can create a `Worker` type that implements `Default`.
 ///
 /// Note that, for the implementation to be generic, `W` also needs to be `Send` and `Sync`.
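The `QueenCell` wrapper above is an instance of a small, general pattern: expose a `&self` API while delegating to a `&mut self` factory behind a lock. Reduced to its essentials with `std`'s `RwLock` (the `MutFactory`/`Cell` names are stand-ins, not the crate's):

use std::sync::RwLock;

// An immutable facade over a factory that needs `&mut self`.
trait MutFactory {
    fn create(&mut self) -> u32;
}

struct Counter(u32);

impl MutFactory for Counter {
    fn create(&mut self) -> u32 {
        self.0 += 1;
        self.0
    }
}

struct Cell<F>(RwLock<F>);

impl<F: MutFactory> Cell<F> {
    // `&self` on the outside, `&mut` on the inside via the write lock
    fn create(&self) -> u32 {
        self.0.write().unwrap().create()
    }
}

fn main() {
    let cell = Cell(RwLock::new(Counter(0)));
    assert_eq!(cell.create(), 1);
    assert_eq!(cell.create(), 2);
}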
If you @@ -49,7 +86,7 @@ pub struct DefaultQueen(PhantomData); impl Queen for DefaultQueen { type Kind = W; - fn create(&mut self) -> Self::Kind { + fn create(&self) -> Self::Kind { Self::Kind::default() } } @@ -68,7 +105,7 @@ impl CloneQueen { impl Queen for CloneQueen { type Kind = W; - fn create(&mut self) -> Self::Kind { + fn create(&self) -> Self::Kind { self.0.clone() } } diff --git a/src/hive/builder.rs b/src/hive/builder.rs index 5532e3e..6fdc4e5 100644 --- a/src/hive/builder.rs +++ b/src/hive/builder.rs @@ -1,4 +1,4 @@ -use super::{Config, Hive}; +use super::{Config, Hive, QueuePair}; use crate::bee::{CloneQueen, DefaultQueen, Queen, Worker}; /// A `Builder` for a [`Hive`](crate::hive::Hive). @@ -259,7 +259,10 @@ impl Builder { /// assert_eq!(husk.queen().num_workers, 8); /// # } /// ``` - pub fn build(self, queen: Q) -> Hive { + pub fn build>( + self, + queen: Q, + ) -> Hive { Hive::new(self.0, queen) } @@ -267,7 +270,9 @@ impl Builder { /// [`Q::default()`](std::default::Default) to create [`Worker`]s. /// /// Returns an error if there was an error spawning the worker threads. - pub fn build_default(self) -> Hive { + pub fn build_default>( + self, + ) -> Hive { Hive::new(self.0, Q::default()) } @@ -323,7 +328,10 @@ impl Builder { /// assert_eq!(sum, 8920); /// # } /// ``` - pub fn build_with(self, worker: W) -> Hive> + pub fn build_with>( + self, + worker: W, + ) -> Hive, P::Global, P::Local, P> where W: Worker + Send + Sync + Clone, { @@ -377,7 +385,9 @@ impl Builder { /// assert_eq!(sum, -25); /// # } /// ``` - pub fn build_with_default(self) -> Hive> + pub fn build_with_default>( + self, + ) -> Hive, P::Global, P::Local, P> where W: Worker + Send + Sync + Default, { diff --git a/src/hive/channel.rs b/src/hive/channel.rs deleted file mode 100644 index ab156cb..0000000 --- a/src/hive/channel.rs +++ /dev/null @@ -1,1002 +0,0 @@ -use super::prelude::*; -use super::task::{ChannelGlobalQueue, ChannelLocalQueues}; -use super::{Config, DerefOutcomes, GlobalQueue, LocalQueues, OutcomeSender, Shared, SpawnError}; -use crate::atomic::Atomic; -use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; -use crossbeam_utils::Backoff; -use std::collections::HashMap; -use std::fmt; -use std::ops::{Deref, DerefMut}; -use std::sync::Arc; -use std::thread::{self, JoinHandle}; - -#[derive(thiserror::Error, Debug)] -#[error("The hive has been poisoned")] -pub struct Poisoned; - -impl> Hive { - /// Spawns a new worker thread with the specified index and with access to the `shared` data. 
- fn try_spawn, L: LocalQueues>( - thread_index: usize, - shared: &Arc>, - ) -> Result, SpawnError> { - let thread_builder = shared.thread_builder(); - let shared = Arc::clone(shared); - // spawn a thread that executes the worker loop - thread_builder.spawn(move || { - // perform one-time initialization of the worker thread - Self::init_thread(thread_index, &shared); - // create a Sentinel that will spawn a new thread on panic until it is cancelled - let sentinel = Sentinel::new(thread_index, Arc::clone(&shared)); - // create a new worker to process tasks - let mut worker = shared.create_worker(); - // execute the main loop: get the next task to process, which decrements the queued - // counter and increments the active counter - while let Some(task) = shared.get_next_task(thread_index) { - // execute the task and dispose of the outcome - Self::execute(task, thread_index, &mut worker, &shared); - // finish the task - decrements the active counter and notifies other threads - shared.finish_task(false); - } - // this is only reachable when the main loop exits due to the task receiver having - // disconnected; cancel the Sentinel so this thread won't be re-spawned on drop - sentinel.cancel(); - }) - } - - /// Creates a new `Hive`. This should only be called from `Builder`. - /// - /// The `Hive` will attempt to spawn the configured number of worker threads - /// (`config.num_threads`) but the actual number of threads available may be lower if there - /// are any errors during spawning. - pub(super) fn new(config: Config, queen: Q) -> Self { - let global_queue = ChannelGlobalQueue::default(); - let shared = Arc::new(Shared::new(config.into_sync(), global_queue, queen)); - shared.init_threads(|thread_index| Self::try_spawn(thread_index, &shared)); - Self(Some(shared)) - } - - #[inline] - fn shared(&self) -> &Arc, ChannelLocalQueues>> { - self.0.as_ref().unwrap() - } - - /// Attempts to increase the number of worker threads by `num_threads`. Returns the number of - /// new worker threads that were successfully started (which may be fewer than `num_threads`), - /// or a `Poisoned` error if the hive has been poisoned. - pub fn grow(&self, num_threads: usize) -> Result { - if num_threads == 0 { - return Ok(0); - } - let shared = self.shared(); - // do not start any new threads if the hive is poisoned - if shared.is_poisoned() { - return Err(Poisoned); - } - let num_started = shared.grow_threads(num_threads, |thread_index| { - Self::try_spawn(thread_index, shared) - }); - Ok(num_started) - } - - /// Sets the number of worker threads to the number of available CPU cores. Returns the number - /// of new threads that were successfully started (which may be `0`), or a `Poisoned` error if - /// the hive has been poisoned. - pub fn use_all_cores(&self) -> Result { - let num_threads = num_cpus::get().saturating_sub(self.max_workers()); - self.grow(num_threads) - } - - /// Sends one input to the `Hive` for processing and returns its ID. The `Outcome` - /// of the task is sent to the `outcome_tx` channel if provided, otherwise it is retained in - /// the `Hive` for later retrieval. - /// - /// This method is called by all the `*apply*` methods. 
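The `*apply*` flavors backed by `send_one` differ only in where the `Outcome` ends up. A condensed sketch using the pre-patch, channel-based API being removed in this file (it assumes the stock `Thunk`/`ThunkWorker` types):

```rust
use beekeeper::bee::stock::{Thunk, ThunkWorker};
use beekeeper::hive::{outcome_channel, Builder, Outcome};

fn main() {
    let hive = Builder::new()
        .num_threads(2)
        .build_with_default::<ThunkWorker<i32>>();

    // `apply` blocks until the single outcome is available
    let outcome = hive.apply(Thunk::of(|| 6 * 7));
    assert!(matches!(outcome, Outcome::Success { value: 42, .. }));

    // `apply_send` returns the task ID immediately; the outcome is
    // delivered to the provided channel
    let (tx, rx) = outcome_channel();
    let _task_id = hive.apply_send(Thunk::of(|| 1), &tx);
    let _outcome = rx.recv().unwrap();

    // `apply_store` also returns immediately; the outcome is retained
    // in the hive for later retrieval
    let _task_id = hive.apply_store(Thunk::of(|| 2));
    hive.join();
}
```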
- #[inline] - fn send_one(&self, input: W::Input, outcome_tx: Option<&OutcomeSender>) -> TaskId { - #[cfg(debug_assertions)] - if self.max_workers() == 0 { - dbg!("WARNING: no worker threads are active for hive"); - } - self.shared().send_one_global(input, outcome_tx) - } - - /// Sends one `input` to the `Hive` for procesing and returns the result, blocking until the - /// result is available. Creates a channel to send the input and receive the outcome. Returns - /// an [`Outcome`] with the task output or an error. - pub fn apply(&self, input: W::Input) -> Outcome { - let (tx, rx) = outcome_channel(); - let task_id = self.send_one(input, Some(&tx)); - rx.recv().unwrap_or_else(|_| Outcome::Missing { task_id }) - } - - /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of - /// the task will be sent to `tx` upon completion. - pub fn apply_send(&self, input: W::Input, tx: &OutcomeSender) -> TaskId { - self.send_one(input, Some(tx)) - } - - /// Sends one `input` to the `Hive` for processing and returns its ID immediately. The - /// [`Outcome`] of the task will be retained and available for later retrieval. - pub fn apply_store(&self, input: W::Input) -> TaskId { - self.send_one(input, None) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns a `Vec` of their - /// task IDs. The [`Outcome`]s of the tasks are sent to the `outcome_tx` channel if provided, - /// otherwise they are retained in the `Hive` for later retrieval. - /// - /// The batch is provided as an [`ExactSizeIterator`], which enables the hive to reserve a - /// range of task IDs (a single atomic operation) rather than one at a time. - /// - /// This method is called by all the `swarm*` methods. - #[inline] - fn send_batch(&self, batch: T, outcome_tx: Option<&OutcomeSender>) -> Vec - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - #[cfg(debug_assertions)] - if self.max_workers() == 0 { - dbg!("WARNING: no worker threads are active for hive"); - } - self.shared().send_batch_global(batch, outcome_tx) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns an iterator over the - /// [`Outcome`]s in the same order as the inputs. - /// - /// This method is more efficient than [`map`](Self::map) when the input is an - /// [`ExactSizeIterator`]. - pub fn swarm(&self, batch: T) -> impl Iterator> - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - let (tx, rx) = outcome_channel(); - let task_ids = self.send_batch(batch, Some(&tx)); - rx.select_ordered(task_ids) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns an unordered iterator - /// over the [`Outcome`]s. - /// - /// The `Outcome`s will be sent in the order they are completed; use [`swarm`](Self::swarm) to - /// instead receive the `Outcome`s in the order they were submitted. This method is more - /// efficient than [`map_unordered`](Self::map_unordered) when the input is an - /// [`ExactSizeIterator`]. - pub fn swarm_unordered(&self, batch: T) -> impl Iterator> - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - let (tx, rx) = outcome_channel(); - let task_ids = self.send_batch(batch, Some(&tx)); - rx.select_unordered(task_ids) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. - /// The [`Outcome`]s of the tasks will be sent to `tx` upon completion. 
- /// - /// This method is more efficient than [`map_send`](Self::map_send) when the input is an - /// [`ExactSizeIterator`]. - pub fn swarm_send(&self, batch: T, outcome_tx: &OutcomeSender) -> Vec - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - self.send_batch(batch, Some(outcome_tx)) - } - - /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. - /// The [`Outcome`]s of the task are retained and available for later retrieval. - /// - /// This method is more efficient than `map_store` when the input is an [`ExactSizeIterator`]. - pub fn swarm_store(&self, batch: T) -> Vec - where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, - { - self.send_batch(batch, None) - } - - /// Iterates over `inputs` and sends each one to the `Hive` for processing and returns an - /// iterator over the [`Outcome`]s in the same order as the inputs. - /// - /// [`swarm`](Self::swarm) should be preferred when `inputs` is an [`ExactSizeIterator`]. - pub fn map( - &self, - inputs: impl IntoIterator, - ) -> impl Iterator> { - let (tx, rx) = outcome_channel(); - let task_ids: Vec<_> = inputs - .into_iter() - .map(|task| self.apply_send(task, &tx)) - .collect(); - rx.select_ordered(task_ids) - } - - /// Iterates over `inputs`, sends each one to the `Hive` for processing, and returns an - /// iterator over the [`Outcome`]s in order they become available. - /// - /// [`swarm_unordered`](Self::swarm_unordered) should be preferred when `inputs` is an - /// [`ExactSizeIterator`]. - pub fn map_unordered( - &self, - inputs: impl IntoIterator, - ) -> impl Iterator> { - let (tx, rx) = outcome_channel(); - // `map` is required (rather than `inspect`) because we need owned items - let task_ids: Vec<_> = inputs - .into_iter() - .map(|task| self.apply_send(task, &tx)) - .collect(); - rx.select_unordered(task_ids) - } - - /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] - /// of task IDs. The [`Outcome`]s of the tasks will be sent to `tx` upon completion. - /// - /// [`swarm_send`](Self::swarm_send) should be preferred when `inputs` is an - /// [`ExactSizeIterator`]. - pub fn map_send( - &self, - inputs: impl IntoIterator, - tx: &OutcomeSender, - ) -> Vec { - inputs - .into_iter() - .map(|input| self.apply_send(input, tx)) - .collect() - } - - /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] - /// of task IDs. The [`Outcome`]s of the task are retained and available for later retrieval. - /// - /// [`swarm_store`](Self::swarm_store) should be preferred when `inputs` is an - /// [`ExactSizeIterator`]. - pub fn map_store(&self, inputs: impl IntoIterator) -> Vec { - inputs - .into_iter() - .map(|input| self.apply_store(input)) - .collect() - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing. - /// Returns an [`OutcomeBatch`] of the outputs and the final state value. - pub fn scan( - &self, - items: impl IntoIterator, - init: St, - f: F, - ) -> (OutcomeBatch, St) - where - F: FnMut(&mut St, T) -> W::Input, - { - let (tx, rx) = outcome_channel(); - let (task_ids, fold_value) = self.scan_send(items, &tx, init, f); - let outcomes = rx.select_unordered(task_ids).into(); - (outcomes, fold_value) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. 
`F` returns an input that is sent to the `Hive` for processing, - /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final - /// state value. - pub fn try_scan( - &self, - items: impl IntoIterator, - init: St, - mut f: F, - ) -> (OutcomeBatch, Vec, St) - where - F: FnMut(&mut St, T) -> Result, - { - let (tx, rx) = outcome_channel(); - let (task_ids, errors, fold_value) = items.into_iter().fold( - (Vec::new(), Vec::new(), init), - |(mut task_ids, mut errors, mut acc), inp| { - match f(&mut acc, inp) { - Ok(input) => task_ids.push(self.apply_send(input, &tx)), - Err(err) => errors.push(err), - } - (task_ids, errors, acc) - }, - ); - let outcomes = rx.select_unordered(task_ids).into(); - (outcomes, errors, fold_value) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. - /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the - /// task IDs and the final state value. - pub fn scan_send( - &self, - items: impl IntoIterator, - tx: &OutcomeSender, - init: St, - mut f: F, - ) -> (Vec, St) - where - F: FnMut(&mut St, T) -> W::Input, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { - let input = f(&mut acc, item); - task_ids.push(self.apply_send(input, tx)); - (task_ids, acc) - }) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, - /// or an error. The outputs are sent to `tx` in the order they become available. This - /// function returns the final state value and a [`Vec`] of results, where each result is - /// either a task ID or an error. - pub fn try_scan_send( - &self, - items: impl IntoIterator, - tx: &OutcomeSender, - init: St, - mut f: F, - ) -> (Vec>, St) - where - F: FnMut(&mut St, T) -> Result, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut results, mut acc), inp| { - results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx))); - (results, acc) - }) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. - /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of - /// the tasks are retained and available for later retrieval. - pub fn scan_store( - &self, - items: impl IntoIterator, - init: St, - mut f: F, - ) -> (Vec, St) - where - F: FnMut(&mut St, T) -> W::Input, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { - let input = f(&mut acc, item); - task_ids.push(self.apply_store(input)); - (task_ids, acc) - }) - } - - /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized - /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, - /// or an error. This function returns the final value of the state value and a [`Vec`] of - /// results, where each result is either a task ID or an error. The [`Outcome`]s of the - /// tasks are retained and available for later retrieval. 
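The `scan` family above is easiest to see with a concrete input. The following sketch is adapted from this module's `test_scan` (pre-patch builder API): the state value threads through the closure, and each value the closure returns becomes a hive input.

```rust
use beekeeper::bee::stock::Caller;
use beekeeper::hive::Builder;

fn main() {
    let hive = Builder::new()
        .num_threads(4)
        .build_with(Caller::of(|i: i32| i * i));
    // the state (a running sum) is threaded through the closure; each
    // value the closure returns is submitted to the hive, which squares it
    let (outcomes, state) = hive.scan(0..10, 0, |acc, i| {
        *acc += i;
        *acc
    });
    assert_eq!(state, 45);
    assert_eq!(outcomes.unwrap().len(), 10);
}
```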
- pub fn try_scan_store( - &self, - items: impl IntoIterator, - init: St, - mut f: F, - ) -> (Vec>, St) - where - F: FnMut(&mut St, T) -> Result, - { - items - .into_iter() - .fold((Vec::new(), init), |(mut results, mut acc), item| { - results.push(f(&mut acc, item).map(|input| self.apply_store(input))); - (results, acc) - }) - } - - /// Blocks the calling thread until all tasks finish. - pub fn join(&self) { - (self.shared()).wait_on_done(); - } - - /// Returns the [`MutexGuard`](parking_lot::MutexGuard) for the [`Queen`]. - /// - /// Note that the `Queen` will remain locked until the returned guard is dropped, and that - /// locking the `Queen` prevents new worker threads from being started. - pub fn queen(&self) -> impl Deref + '_ { - (self.shared()).queen.lock() - } - - /// Returns the number of worker threads that have been requested, i.e., the maximum number of - /// tasks that could be processed concurrently. This may be greater than - /// [`active_workers`](Self::active_workers) if any of the worker threads failed to start. - pub fn max_workers(&self) -> usize { - (self.shared()).config.num_threads.get_or_default() - } - - /// Returns the number of worker threads that have been successfully started. This may be - /// fewer than [`max_workers`](Self::max_workers) if any of the worker threads failed to start. - pub fn alive_workers(&self) -> usize { - (self.shared()) - .spawn_results - .lock() - .iter() - .filter(|result| result.is_ok()) - .count() - } - - /// Returns `true` if there are any "dead" worker threads that failed to spawn. - pub fn has_dead_workers(&self) -> bool { - (self.shared()) - .spawn_results - .lock() - .iter() - .any(|result| result.is_err()) - } - - /// Attempts to respawn any dead worker threads. Returns the number of worker threads that were - /// successfully respawned. - pub fn revive_workers(&self) -> usize { - let shared = self.shared(); - shared.respawn_dead_threads(|thread_index| Self::try_spawn(thread_index, shared)) - } - - /// Returns the number of tasks currently (queued for processing, being processed). - pub fn num_tasks(&self) -> (u64, u64) { - (self.shared()).num_tasks() - } - - /// Returns the number of times one of this `Hive`'s worker threads has panicked. - pub fn num_panics(&self) -> usize { - (self.shared()).num_panics.get() - } - - /// Returns `true` if this `Hive` has been poisoned - i.e., its internal state has been - /// corrupted such that it is no longer able to process tasks. - /// - /// Note that, when a `Hive` is poisoned, it is still possible to call methods that extract - /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., - /// [`try_into_husk`](Self::try_into_husk)). - pub fn is_poisoned(&self) -> bool { - (self.shared()).is_poisoned() - } - - /// Returns `true` if the suspended flag is set. - pub fn is_suspended(&self) -> bool { - (self.shared()).is_suspended() - } - - /// Sets the suspended flag, which notifies worker threads that they a) MAY terminate their - /// current task early (returning an [`Outcome::Unprocessed`]), and b) MUST not accept new - /// tasks, and instead block until the suspended flag is cleared. - /// - /// Call [`resume`](Self::resume) to unset the suspended flag and continue processing tasks. - /// - /// Note: this does *not* prevent new tasks from being queued, and there is a window of time - /// (~1 second) after the suspended flag is set within which a worker thread may still accept a - /// new task. 
- /// - /// # Examples - /// - /// ``` - /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::Builder; - /// use std::thread; - /// use std::time::Duration; - /// - /// # fn main() { - /// let hive = Builder::new() - /// .num_threads(4) - /// .build_with_default::>(); - /// hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); - /// thread::sleep(Duration::from_secs(1)); // Allow first set of tasks to be started. - /// // There should be 4 active tasks and 6 queued tasks. - /// hive.suspend(); - /// assert_eq!(hive.num_tasks(), (6, 4)); - /// // Wait for active tasks to complete. - /// hive.join(); - /// assert_eq!(hive.num_tasks(), (6, 0)); - /// hive.resume(); - /// // Wait for remaining tasks to complete. - /// hive.join(); - /// assert_eq!(hive.num_tasks(), (0, 0)); - /// # } - /// ``` - pub fn suspend(&self) { - (self.shared()).set_suspended(true); - } - - /// Unsets the suspended flag, allowing worker threads to continue processing queued tasks. - pub fn resume(&self) { - (self.shared()).set_suspended(false); - } - - /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over - /// the input values. - fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator { - (self.shared()) - .take_unprocessed() - .into_iter() - .map(|outcome| match outcome { - Outcome::Unprocessed { input, task_id: _ } => input, - _ => unreachable!(), - }) - } - - /// If this `Hive` is suspended, resumes this `Hive` and re-submits any unprocessed tasks for - /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that - /// were resumed. - pub fn resume_send(&self, outcome_tx: &OutcomeSender) -> Vec { - (self.shared()) - .set_suspended(false) - .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)) - .unwrap_or_default() - } - - /// If this `Hive` is suspended, resumes this `Hive` and re-submit any unprocessed tasks for - /// processing, with their results to be stored in the queue. Returns a [`Vec`] of task IDs - /// that were resumed. - pub fn resume_store(&self) -> Vec { - (self.shared()) - .set_suspended(false) - .then(|| self.swarm_store(self.take_unprocessed_inputs())) - .unwrap_or_default() - } - - /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s. - pub fn take_stored(&self) -> HashMap> { - (self.shared()).take_outcomes() - } - - /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this - /// `Hive`, including any stored task outcomes, and all the data necessary to create a new - /// `Hive`. - /// - /// If this `Hive` has been cloned, and those clones have not been dropped, this method - /// returns `None` since it cannot take exclusive ownership of the internal shared data. - /// - /// This method first joins on the `Hive` to wait for all tasks to finish. 
- pub fn try_into_husk(mut self) -> Option> { - if (self.shared()).num_referrers() > 1 { - return None; - } - // take the inner value and replace it with `None` - let mut shared = self.0.take().unwrap(); - // close the global queue to prevent new tasks from being submitted - shared.global_queue.close(); - // wait for all tasks to finish - shared.wait_on_done(); - // wait for worker threads to drop, then take ownership of the shared data and convert it - // into a Husk - let mut backoff = None::; - loop { - // TODO: may want to have some timeout or other kind of limit to prevent this from - // looping forever if a worker thread somehow gets stuck, or if the `num_referrers` - // counter is corrupted - shared = match Arc::try_unwrap(shared) { - Ok(shared) => { - return Some(shared.into_husk()); - } - Err(shared) => { - backoff.get_or_insert_with(Backoff::new).spin(); - shared - } - }; - } - } -} - -impl Default for Hive> { - fn default() -> Self { - Builder::default().build_with_default::() - } -} - -impl> Clone for Hive { - /// Creates a shallow copy of this `Hive` containing references to its same internal state, - /// i.e., all clones of a `Hive` submit tasks to the same shared worker thread pool. - fn clone(&self) -> Self { - let shared = self.0.as_ref().unwrap(); - shared.referrer_is_cloning(); - Self(Some(shared.clone())) - } -} - -impl> fmt::Debug for Hive { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(shared) = self.0.as_ref() { - f.debug_struct("Hive").field("shared", &shared).finish() - } else { - f.write_str("Hive {}") - } - } -} - -impl> PartialEq for Hive { - fn eq(&self, other: &Hive) -> bool { - let self_shared = self.shared(); - let other_shared = &other.shared(); - Arc::ptr_eq(self_shared, other_shared) - } -} - -impl> Eq for Hive {} - -impl> DerefOutcomes for Hive { - #[inline] - fn outcomes_deref(&self) -> impl Deref>> { - (self.shared()).outcomes() - } - - #[inline] - fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { - (self.shared()).outcomes() - } -} - -impl> Drop for Hive { - fn drop(&mut self) { - // if this Hive has already been turned into a Husk, it's inner value will be `None` - if let Some(shared) = self.0.as_ref() { - // reduce the referrer count - let _ = shared.referrer_is_dropping(); - // if this Hive is the only one with a pointer to the shared data, poison it - // to prevent any worker threads that still have access to the shared data from - // re-spawning. - if shared.num_referrers() == 0 { - shared.poison(); - } - } - } -} - -/// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker -/// thread if it panics. -struct Sentinel -where - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, -{ - thread_index: usize, - shared: Arc>, - active: bool, -} - -impl Sentinel -where - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, -{ - fn new(thread_index: usize, shared: Arc>) -> Self { - Self { - thread_index, - shared, - active: true, - } - } - - /// Cancel and destroy this sentinel. 
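In outline, the sentinel is a drop guard: if the worker loop unwinds before `cancel` is called, `Drop` runs during the panic and respawns the thread. A stripped-down, std-only sketch of the pattern (the real implementation above also finishes the in-flight task and checks whether the hive is poisoned):

```rust
use std::thread;

/// Drop guard standing in for the `Sentinel` above.
struct Sentinel {
    active: bool,
}

impl Sentinel {
    /// Called on clean exit from the worker loop; disarms the guard.
    fn cancel(mut self) {
        self.active = false;
    }
}

impl Drop for Sentinel {
    fn drop(&mut self) {
        // still armed: the worker loop unwound without calling `cancel`
        if self.active && thread::panicking() {
            // the real implementation re-spawns the worker thread here
            eprintln!("worker thread panicked; would respawn");
        }
    }
}

fn worker_loop() {
    let sentinel = Sentinel { active: true };
    // ... pull and execute tasks; a panic anywhere in here leaves the
    // sentinel armed, so `Drop` runs while unwinding ...
    sentinel.cancel();
}

fn main() {
    worker_loop();
}
```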
- fn cancel(mut self) { - self.active = false; - } -} - -impl Drop for Sentinel -where - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, -{ - fn drop(&mut self) { - if self.active { - // if the sentinel is active, that means the thread panicked during task execution, so - // we have to finish the task here before respawning - self.shared.finish_task(thread::panicking()); - // only respawn if the sentinel is active and the hive has not been poisoned - if !self.shared.is_poisoned() { - // can't do anything with the previous result - let _ = self - .shared - .respawn_thread(self.thread_index, |thread_index| { - Hive::try_spawn(thread_index, &self.shared) - }); - } - } - } -} - -#[cfg(not(feature = "affinity"))] -mod no_affinity { - use crate::bee::{Queen, Worker}; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Shared}; - - impl> Hive { - #[inline] - pub(super) fn init_thread, L: LocalQueues>( - _: usize, - _: &Shared, - ) { - } - } -} - -#[cfg(feature = "affinity")] -mod affinity { - use crate::bee::{Queen, Worker}; - use crate::hive::cores::Cores; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Poisoned, Shared}; - - impl> Hive { - /// Tries to pin the worker thread to a specific CPU core. - #[inline] - pub(super) fn init_thread, L: LocalQueues>( - thread_index: usize, - shared: &Shared, - ) { - if let Some(core) = shared.get_core_affinity(thread_index) { - core.try_pin_current(); - } - } - - /// Attempts to increase the number of worker threads by `num_threads`. - /// - /// The provided `affinity` specifies additional CPU core indices to which the worker - /// threads may be pinned - these are added to the existing pool of core indices (if any). - /// - /// Returns the number of new worker threads that were successfully started (which may be - /// fewer than `num_threads`) or a `Poisoned` error if the hive has been poisoned. - pub fn grow_with_affinity>( - &self, - num_threads: usize, - affinity: C, - ) -> Result { - (self.shared()).add_core_affinity(affinity.into()); - self.grow(num_threads) - } - - /// Sets the number of worker threads to the number of available CPU cores. An attempt is - /// made to pin each worker thread to a different CPU core. - /// - /// Returns the number of new threads spun up (if any) or a `Poisoned` error if the hive - /// has been poisoned. - pub fn use_all_cores_with_affinity(&self) -> Result { - (self.shared()).add_core_affinity(Cores::all()); - self.use_all_cores() - } - } -} - -#[cfg(feature = "batching")] -mod batching { - use crate::bee::{Queen, Worker}; - use crate::hive::Hive; - - impl> Hive { - /// Returns the batch size for worker threads. - pub fn worker_batch_size(&self) -> usize { - (self.shared()).batch_size() - } - - /// Sets the batch size for worker threads. This will block the current thread until all - /// worker thread queues can be resized. 
- pub fn set_worker_batch_size(&self, batch_size: usize) { - (self.shared()).set_batch_size(batch_size); - } - } -} - -struct HiveTaskContext<'a, W, Q, G, L> -where - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, -{ - thread_index: usize, - shared: &'a Arc>, - outcome_tx: Option<&'a OutcomeSender>, -} - -impl TaskContext for HiveTaskContext<'_, W, Q, G, L> -where - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, -{ - fn cancel_tasks(&self) -> bool { - self.shared.is_suspended() - } - - fn submit_task(&self, input: W::Input) -> TaskId { - self.shared - .send_one_local(input, self.outcome_tx, self.thread_index) - } -} - -impl fmt::Debug for HiveTaskContext<'_, W, Q, G, L> -where - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HiveTaskContext").finish() - } -} - -#[cfg(not(feature = "retry"))] -mod no_retry { - use super::HiveTaskContext; - use crate::bee::{Context, Queen, Worker}; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, Shared, Task}; - use std::sync::Arc; - - impl> Hive { - pub(super) fn execute, L: LocalQueues>( - task: Task, - thread_index: usize, - worker: &mut W, - shared: &Arc>, - ) { - let (task_id, input, outcome_tx) = task.into_parts(); - let task_ctx = HiveTaskContext { - thread_index, - shared, - outcome_tx: outcome_tx.as_ref(), - }; - let ctx = Context::new(task_id, Some(Box::new(&task_ctx))); - let result = worker.apply(input, &ctx); - let subtask_ids = ctx.into_subtask_ids(); - let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); - shared.send_or_store_outcome(outcome, outcome_tx); - } - } -} - -#[cfg(feature = "retry")] -mod retry { - use super::HiveTaskContext; - use crate::bee::{ApplyError, Context, Queen, Worker}; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, Shared, Task}; - use std::sync::Arc; - - impl> Hive { - pub(super) fn execute, L: LocalQueues>( - task: Task, - thread_index: usize, - worker: &mut W, - shared: &Arc>, - ) { - let (task_id, input, attempt, outcome_tx) = task.into_parts(); - let task_ctx = HiveTaskContext { - thread_index, - shared, - outcome_tx: outcome_tx.as_ref(), - }; - let ctx = Context::new(task_id, attempt, Some(&task_ctx)); - // execute the task until it succeeds or we reach maximum retries - this should - // be the only place where a panic can occur - let result = worker.apply(input, &ctx); - let subtask_ids = ctx.into_subtask_ids(); - match result { - Err(ApplyError::Retryable { input, .. }) - if subtask_ids.is_none() && shared.can_retry(attempt) => - { - shared.send_retry(task_id, input, outcome_tx, attempt + 1, thread_index); - } - result => { - let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); - shared.send_or_store_outcome(outcome, outcome_tx); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::Poisoned; - use crate::bee::stock::{Caller, Thunk, ThunkWorker}; - use crate::hive::{outcome_channel, Builder, Outcome, OutcomeIteratorExt}; - use std::collections::HashMap; - use std::thread; - use std::time::Duration; - - #[test] - fn test_suspend() { - let hive = Builder::new() - .num_threads(4) - .build_with_default::>(); - let outcome_iter = - hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); - // Allow first set of tasks to be started. - thread::sleep(Duration::from_secs(1)); - // There should be 4 active tasks and 6 queued tasks. 
- hive.suspend(); - assert_eq!(hive.num_tasks(), (6, 4)); - // Wait for active tasks to complete. - hive.join(); - assert_eq!(hive.num_tasks(), (6, 0)); - hive.resume(); - // Wait for remaining tasks to complete. - hive.join(); - assert_eq!(hive.num_tasks(), (0, 0)); - let outputs: Vec<_> = outcome_iter.into_outputs().collect(); - assert_eq!(outputs.len(), 10); - } - - #[test] - fn test_spawn_after_poison() { - let hive = Builder::new() - .num_threads(4) - .build_with_default::>(); - assert_eq!(hive.max_workers(), 4); - assert_eq!(hive.alive_workers(), 4); - // poison hive using private method - hive.0.as_ref().unwrap().poison(); - // attempt to spawn a new task - assert!(matches!(hive.grow(1), Err(Poisoned))); - // make sure the worker count wasn't increased - assert_eq!(hive.max_workers(), 4); - assert_eq!(hive.alive_workers(), 4); - } - - #[test] - fn test_apply_after_poison() { - let hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i: usize| i * 2)); - // poison hive using private method - hive.0.as_ref().unwrap().poison(); - // submit a task, check that it comes back unprocessed - let (tx, rx) = outcome_channel(); - let sent_input = 1; - let sent_task_id = hive.apply_send(sent_input, &tx); - let outcome = rx.recv().unwrap(); - match outcome { - Outcome::Unprocessed { input, task_id } => { - assert_eq!(input, sent_input); - assert_eq!(task_id, sent_task_id); - } - _ => panic!("Expected unprocessed outcome"), - } - } - - #[test] - fn test_swarm_after_poison() { - let hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i: usize| i * 2)); - // poison hive using private method - hive.0.as_ref().unwrap().poison(); - // submit a task, check that it comes back unprocessed - let (tx, rx) = outcome_channel(); - let inputs = 0..10; - let task_ids: HashMap = hive - .swarm_send(inputs.clone(), &tx) - .into_iter() - .zip(inputs) - .collect(); - for outcome in rx.into_iter().take(10) { - match outcome { - Outcome::Unprocessed { input, task_id } => { - let expected_input = task_ids.get(&task_id); - assert!(expected_input.is_some()); - assert_eq!(input, *expected_input.unwrap()); - } - _ => panic!("Expected unprocessed outcome"), - } - } - } -} diff --git a/src/hive/husk.rs b/src/hive/husk.rs index 0c12034..9795416 100644 --- a/src/hive/husk.rs +++ b/src/hive/husk.rs @@ -1,6 +1,6 @@ use super::{ Builder, Config, DerefOutcomes, Hive, Outcome, OutcomeBatch, OutcomeSender, OutcomeStore, - OwnedOutcomes, + OwnedOutcomes, QueuePair, }; use crate::bee::{Queen, TaskId, Worker}; use std::collections::HashMap; @@ -56,8 +56,8 @@ impl> Husk { /// Consumes this `Husk` and returns a new `Hive` with the same configuration and `Queen` as /// the one that produced this `Husk`. - pub fn into_hive(self) -> Hive { - self.as_builder().build(self.queen) + pub fn into_hive>(self) -> Hive { + self.as_builder().build::(self.queen) } /// Consumes this `Husk` and creates a new `Hive` with the same configuration as the one that @@ -65,16 +65,16 @@ impl> Husk { /// be sent to `tx`. Returns the new `Hive` and the IDs of the tasks that were queued. /// /// This method returns a `SpawnError` if there is an error creating the new `Hive`. 
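With `QueuePair` threaded through these signatures, callers now choose the queue implementation explicitly via a turbofish. The following is a hypothetical sketch distilled from the updated tests below; generic argument lists are abbreviated throughout this excerpt, so the exact parameter shapes are assumptions (`ChannelQueues` is the crate-internal channel-backed pair, so this only compiles inside the crate's own test modules):

```rust
use crate::bee::stock::{Thunk, ThunkWorker};
use crate::hive::queue::ChannelQueues;
use crate::hive::Builder;

fn rebuild_from_husk() {
    // the queue implementation is now an explicit type parameter
    let hive = Builder::new()
        .num_threads(0)
        .build_with_default::<ThunkWorker<usize>, ChannelQueues<_>>();
    let _ = hive.map_store((0..10).map(|i| Thunk::of(move || i)));
    hive.suspend();
    let husk = hive.try_into_husk().unwrap();
    // the replacement hive must also be told which `QueuePair` to use
    let (_hive2, _ids) = husk.into_hive_swarm_store_unprocessed::<ChannelQueues<_>>();
}
```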
- pub fn into_hive_swarm_send_unprocessed( + pub fn into_hive_swarm_send_unprocessed>( mut self, tx: &OutcomeSender, - ) -> (Hive, Vec) { + ) -> (Hive, Vec) { let unprocessed: Vec<_> = self .remove_all_unprocessed() .into_iter() .map(|(_, input)| input) .collect(); - let hive = self.as_builder().build(self.queen); + let hive = self.as_builder().build::(self.queen); let task_ids = hive.swarm_send(unprocessed, tx); (hive, task_ids) } @@ -85,13 +85,15 @@ impl> Husk { /// of the tasks that were queued. /// /// This method returns a `SpawnError` if there is an error creating the new `Hive`. - pub fn into_hive_swarm_store_unprocessed(mut self) -> (Hive, Vec) { + pub fn into_hive_swarm_store_unprocessed>( + mut self, + ) -> (Hive, Vec) { let unprocessed: Vec<_> = self .remove_all_unprocessed() .into_iter() .map(|(_, input)| input) .collect(); - let hive = self.as_builder().build(self.queen); + let hive = self.as_builder().build::(self.queen); let task_ids = hive.swarm_store(unprocessed); (hive, task_ids) } @@ -124,6 +126,7 @@ impl> OwnedOutcomes for Husk { #[cfg(test)] mod tests { use crate::bee::stock::{PunkWorker, Thunk, ThunkWorker}; + use crate::hive::queue::ChannelQueues; use crate::hive::{outcome_channel, Builder, Outcome, OutcomeIteratorExt, OutcomeStore}; #[test] @@ -131,7 +134,7 @@ mod tests { // don't spin up any worker threads so that no tasks will be processed let hive = Builder::new() .num_threads(0) - .build_with_default::>(); + .build_with_default::, ChannelQueues<_>>(); let mut task_ids = hive.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive.suspend(); @@ -156,12 +159,12 @@ mod tests { // don't spin up any worker threads so that no tasks will be processed let hive1 = Builder::new() .num_threads(0) - .build_with_default::>(); + .build_with_default::, ChannelQueues<_>>(); let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive1.suspend(); let husk1 = hive1.try_into_husk().unwrap(); - let (hive2, _) = husk1.into_hive_swarm_store_unprocessed(); + let (hive2, _) = husk1.into_hive_swarm_store_unprocessed::>(); // now spin up worker threads to process the tasks hive2.grow(8).expect("error spawning threads"); hive2.join(); @@ -176,13 +179,13 @@ mod tests { // don't spin up any worker threads so that no tasks will be processed let hive1 = Builder::new() .num_threads(0) - .build_with_default::>(); + .build_with_default::, ChannelQueues<_>>(); let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive1.suspend(); let husk1 = hive1.try_into_husk().unwrap(); let (tx, rx) = outcome_channel(); - let (hive2, task_ids) = husk1.into_hive_swarm_send_unprocessed(&tx); + let (hive2, task_ids) = husk1.into_hive_swarm_send_unprocessed::>(&tx); // now spin up worker threads to process the tasks hive2.grow(8).expect("error spawning threads"); hive2.join(); @@ -200,7 +203,7 @@ mod tests { fn test_into_result() { let hive = Builder::new() .num_threads(4) - .build_with_default::>(); + .build_with_default::, ChannelQueues<_>>(); hive.map_store((0..10).map(|i| Thunk::of(move || i))); hive.join(); let mut outputs = hive.try_into_husk().unwrap().into_parts().1.unwrap(); @@ -213,7 +216,7 @@ mod tests { fn test_into_result_panic() { let hive = Builder::new() .num_threads(4) - .build_with_default::>(); + .build_with_default::, ChannelQueues<_>>(); hive.map_store( (0..10).map(|i| Thunk::of(move || if i 
== 5 { panic!("oh no!") } else { i })), ); diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 1b7964e..5a95368 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -356,12 +356,13 @@ //! ([`Husk::as_builder`](crate::hive::husk::Husk::as_builder)) or a new `Hive` //! ([`Husk::into_hive`](crate::hive::husk::Husk::into_hive)). mod builder; -mod channel; mod config; +mod core; mod counter; mod gate; mod husk; mod outcome; +mod queue; //mod scoped; mod shared; mod task; @@ -371,7 +372,6 @@ mod task; pub mod cores; pub use self::builder::Builder; -pub use self::channel::Poisoned; #[cfg(feature = "batching")] pub use self::config::set_batch_size_default; pub use self::config::{reset_defaults, set_num_threads_default, set_num_threads_default_all}; @@ -379,8 +379,10 @@ pub use self::config::{reset_defaults, set_num_threads_default, set_num_threads_ pub use self::config::{ set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, }; +pub use self::core::Poisoned; pub use self::husk::Husk; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeQueue, OutcomeStore}; +pub(crate) use self::queue::{ChannelGlobalQueue, ChannelQueues, DefaultLocalQueues}; /// Sender type for channel used to send task outcomes. pub type OutcomeSender = crate::channel::Sender>; @@ -404,7 +406,6 @@ pub mod prelude { use self::counter::DualCounter; use self::gate::{Gate, PhasedGate}; use self::outcome::{DerefOutcomes, OwnedOutcomes}; -use self::task::{ChannelGlobalQueue, ChannelLocalQueues}; use crate::atomic::{AtomicAny, AtomicBool, AtomicOption, AtomicUsize}; use crate::bee::{Queen, TaskId, Worker}; use parking_lot::Mutex; @@ -422,15 +423,13 @@ type U64 = AtomicOption; /// A pool of worker threads that each execute the same function. /// /// See the [module documentation](crate::hive) for details. -pub struct Hive>( - #[allow(clippy::type_complexity)] - Option, ChannelLocalQueues>>>, -); - -/// Type alias for the input task channel sender -type TaskSender = crossbeam_channel::Sender>; -/// Type alias for the input task channel receiver -type TaskReceiver = crossbeam_channel::Receiver>; +pub struct Hive< + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, +>(Option>>); /// Internal representation of a task to be processed by a `Hive`. #[derive(Debug)] @@ -468,15 +467,15 @@ struct Config { } /// Data shared by all worker threads in a `Hive`. -struct Shared, G: GlobalQueue, L: LocalQueues> { +struct Shared, P: QueuePair> { /// core configuration parameters config: Config, + /// the `Queen` used to create new workers + queen: Q, /// global task queue used by the `Hive` to send tasks to the worker threads - global_queue: G, + global_queue: P::Global, /// local queues used by worker threads to manage tasks - local_queues: L, - /// the `Queen` used to create new workers - queen: Mutex, + local_queues: P::Local, /// The results of spawning each worker spawn_results: Mutex, SpawnError>>>, /// allows for 2^48 queued tasks and 2^16 active tasks @@ -502,6 +501,13 @@ struct Shared, G: GlobalQueue, L: LocalQueues, } +pub trait QueuePair: Sized + 'static { + type Global: GlobalQueue; + type Local: LocalQueues; + + fn new() -> (Self::Global, Self::Local); +} + #[derive(thiserror::Error, Debug)] pub enum GlobalPopError { #[error("Task queue is closed")] @@ -511,7 +517,7 @@ pub enum GlobalPopError { } /// Trait that provides access to a global queue for receiving tasks. 
-trait GlobalQueue: Sized + Send + Sync + 'static { +pub trait GlobalQueue: Sized + Send + Sync + 'static { /// Tries to add a task to the global queue. /// /// Returns an error if the queue is disconnected. @@ -525,6 +531,9 @@ trait GlobalQueue: Sized + Send + Sync + 'static { /// Returns an error if the queue is disconnected. fn try_pop(&self) -> Option, GlobalPopError>>; + /// Returns an iterator that yields tasks from the queue until it is empty. + fn try_iter(&self) -> impl Iterator> + '_; + /// Drains all tasks from the global queue and returns them as an iterator. fn drain(&self) -> Vec>; @@ -537,42 +546,42 @@ /// Ideally, these queues would be managed in a global thread-local data structure, but since tasks /// are `Worker`-specific, each `Hive` must have its own set of queues stored within the Hive's /// shared data. -trait LocalQueues>: Sized + Default + Send + Sync + 'static { +pub trait LocalQueues>: Sized + Send + Sync + 'static { /// Initializes the local queues for the given range of worker thread indices. - fn init_for_threads>( + fn init_for_threads, P: QueuePair>( &self, start_index: usize, end_index: usize, - shared: &Shared, + shared: &Shared, ); /// Changes the size of the local queues to `size`. #[cfg(feature = "batching")] - fn resize>( + fn resize, P: QueuePair>( &self, start_index: usize, end_index: usize, new_size: usize, - shared: &Shared, + shared: &Shared, ); /// Attempts to add a task to the local queue if space is available, otherwise adds it to the /// global queue. If adding to the global queue fails, the task is abandoned (converted to an /// `Unprocessed` outcome and sent to the outcome channel or stored in the hive). - fn push>( + fn push, P: QueuePair>( &self, task: Task, thread_index: usize, - shared: &Shared, + shared: &Shared, ); /// Attempts to remove a task from the local queue for the given worker thread index. /// /// Returns `None` if there is no task immediately available. - fn try_pop>( + fn try_pop, P: QueuePair>( &self, thread_index: usize, - shared: &Shared, + shared: &Shared, ) -> Option>; /// Drains all tasks from all local queues and returns them as an iterator. @@ -581,25 +590,27 @@ trait LocalQueues>: Sized + Default + Send + Sync + /// Attempts to add `task` to the local retry queue. Returns the earliest `Instant` at which it /// might be retried.
#[cfg(feature = "retry")] - fn retry>( + fn retry, P: QueuePair>( &self, task: Task, thread_index: usize, - shared: &Shared, + shared: &Shared, ) -> Option; } #[cfg(test)] mod tests { + use super::queue::{ChannelQueues, DefaultLocalQueues}; use super::{Builder, Hive, Outcome, OutcomeIteratorExt, OutcomeStore}; use crate::barrier::IndexedBarrier; use crate::bee::stock::{Caller, OnceCaller, RefCaller, Thunk, ThunkWorker}; use crate::bee::{ - ApplyError, ApplyRefError, Context, DefaultQueen, Queen, RefWorker, RefWorkerResult, - TaskId, Worker, WorkerResult, + ApplyError, ApplyRefError, Context, DefaultQueen, QueenCell, QueenMut, RefWorker, + RefWorkerResult, TaskId, Worker, WorkerResult, }; use crate::channel::{Message, ReceiverExt}; use crate::hive::outcome::DerefOutcomes; + use crate::hive::queue::ChannelGlobalQueue; use std::fmt::Debug; use std::io::{self, BufRead, BufReader, Write}; use std::process::{Child, ChildStdin, ChildStdout, Command, ExitStatus, Stdio}; @@ -615,1478 +626,1517 @@ mod tests { const SHORT_TASK: Duration = Duration::from_secs(2); const LONG_TASK: Duration = Duration::from_secs(5); - type ThunkHive = Hive, DefaultQueen>>; + type Global = ChannelGlobalQueue; + type Local = DefaultLocalQueues>; + type ThunkHive = Hive< + ThunkWorker, + DefaultQueen>, + Global>, + Local>, + ChannelQueues>, + >; /// Convenience function that returns a `Hive` configured with the global defaults, and the /// specified number of workers that execute `Thunk`s, i.e. closures that return `T`. pub fn thunk_hive(num_threads: usize) -> ThunkHive { Builder::default() .num_threads(num_threads) - .build_with_default() - } - - #[test] - fn test_works() { - let hive = thunk_hive(TEST_TASKS); - let (tx, rx) = mpsc::channel(); - assert_eq!(hive.max_workers(), TEST_TASKS); - assert_eq!(hive.alive_workers(), TEST_TASKS); - assert!(!hive.has_dead_workers()); - for _ in 0..TEST_TASKS { - let tx = tx.clone(); - hive.apply_store(Thunk::of(move || { - tx.send(1).unwrap(); - })); - } - assert_eq!(rx.iter().take(TEST_TASKS).sum::(), TEST_TASKS); - } - - #[test] - fn test_grow_from_zero() { - let hive = thunk_hive::(0); - // check that with 0 threads no tasks are scheduled - let (tx, rx) = super::outcome_channel(); - let _ = hive.apply_send(Thunk::of(|| 0), &tx); - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks().0, 1); - assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); - hive.grow(1).expect("error spawning threads"); - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks().0, 0); - assert!(matches!( - rx.try_recv_msg(), - Message::Received(Outcome::Success { value: 0, .. 
}) - )); - } - - #[test] - fn test_grow() { - let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - // queue some long-running tasks - for _ in 0..TEST_TASKS { - hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); - } - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - // increase the number of threads - let new_threads = 4; - let total_threads = new_threads + TEST_TASKS; - hive.grow(new_threads).expect("error spawning threads"); - // queue some more long-running tasks - for _ in 0..new_threads { - hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); - } - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks().1, total_threads as u64); - let husk = hive.try_into_husk().unwrap(); - assert_eq!(husk.iter_successes().count(), total_threads); - } - - #[test] - fn test_suspend() { - let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - // queue some long-running tasks - let total_tasks = 2 * TEST_TASKS; - for _ in 0..total_tasks { - hive.apply_store(Thunk::of(|| thread::sleep(SHORT_TASK))); - } - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, TEST_TASKS as u64)); - hive.suspend(); - // active tasks should finish but no more tasks should be started - thread::sleep(SHORT_TASK); - assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, 0)); - assert_eq!(hive.num_successes(), TEST_TASKS); - hive.resume(); - // new tasks should start - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64)); - thread::sleep(SHORT_TASK); - // all tasks should be completed - assert_eq!(hive.num_tasks(), (0, 0)); - assert_eq!(hive.num_successes(), total_tasks); - } - - #[derive(Debug, Default)] - struct MyRefWorker; - - impl RefWorker for MyRefWorker { - type Input = u8; - type Output = u8; - type Error = (); - - fn apply_ref( - &mut self, - input: &Self::Input, - ctx: &Context, - ) -> RefWorkerResult { - for _ in 0..3 { - thread::sleep(Duration::from_secs(1)); - if ctx.is_cancelled() { - return Err(ApplyRefError::Cancelled); - } - } - Ok(*input) - } - } - - #[test] - fn test_suspend_with_cancelled_tasks() { - let hive = Builder::new() - .num_threads(TEST_TASKS) - .build_with_default::(); - hive.swarm_store(0..TEST_TASKS as u8); - hive.suspend(); - // wait for tasks to be cancelled - thread::sleep(Duration::from_secs(2)); - hive.resume_store(); - thread::sleep(Duration::from_secs(1)); - // unprocessed tasks should be requeued - assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - thread::sleep(Duration::from_secs(3)); - assert_eq!(hive.num_successes(), TEST_TASKS); - } - - #[test] - fn test_num_tasks_active() { - let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - for _ in 0..2 * TEST_TASKS { - hive.apply_store(Thunk::of(|| loop { - thread::sleep(LONG_TASK) - })); - } - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - let num_threads = hive.max_workers(); - assert_eq!(num_threads, TEST_TASKS); - } - - #[test] - fn test_all_threads() { - let hive = Builder::new() - .with_thread_per_core() - .build_with_default::>(); - let num_threads = num_cpus::get(); - for _ in 0..num_threads { - hive.apply_store(Thunk::of(|| loop { - thread::sleep(LONG_TASK) - })); - } - thread::sleep(ONE_SEC); - assert_eq!(hive.num_tasks().1, num_threads as u64); - let num_threads = hive.max_workers(); - assert_eq!(num_threads, num_threads); - } - - #[test] - fn test_panic() { - let hive = thunk_hive(TEST_TASKS); - let (tx, _) = 
super::outcome_channel(); - // Panic all the existing threads. - for _ in 0..TEST_TASKS { - hive.apply_send(Thunk::of(|| panic!("intentional panic")), &tx); - } - hive.join(); - // Ensure that none of the threads have panicked - assert_eq!(hive.num_panics(), TEST_TASKS); - let husk = hive.try_into_husk().unwrap(); - assert_eq!(husk.num_panics(), TEST_TASKS); - } - - #[test] - fn test_catch_panic() { - let hive = Builder::new() - .num_threads(TEST_TASKS) - .build_with(RefCaller::of(|_: &u8| -> Result { - panic!("intentional panic") - })); - let (tx, rx) = super::outcome_channel(); - // Panic all the existing threads. - for i in 0..TEST_TASKS { - hive.apply_send(i as u8, &tx); - } - hive.join(); - // Ensure that none of the threads have panicked - assert_eq!(hive.num_panics(), 0); - // Check that all the results are Outcome::Panic - for outcome in rx.into_iter().take(TEST_TASKS) { - assert!(matches!(outcome, Outcome::Panic { .. })); - } - } - - #[test] - fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() { - let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - let waiter = Arc::new(Barrier::new(TEST_TASKS + 1)); - let waiter_count = Arc::new(AtomicUsize::new(0)); - - // panic all the existing threads in a bit - for _ in 0..TEST_TASKS { - let waiter = waiter.clone(); - let waiter_count = waiter_count.clone(); - hive.apply_store(Thunk::of(move || { - waiter_count.fetch_add(1, Ordering::SeqCst); - waiter.wait(); - panic!("intentional panic"); - })); - } - - // queued tasks will not be processed after the hive is dropped, so we need to wait to make - // sure that all tasks have started and are waiting on the barrier - // TODO: find a Barrier implementation with try_wait() semantics - thread::sleep(Duration::from_secs(1)); - assert_eq!(waiter_count.load(Ordering::SeqCst), TEST_TASKS); - - drop(hive); - - // unblock the tasks and allow them to panic - waiter.wait(); - } - - #[test] - fn test_massive_task_creation() { - let test_tasks = 4_200_000; - - let hive = thunk_hive(TEST_TASKS); - let b0 = IndexedBarrier::new(TEST_TASKS); - let b1 = IndexedBarrier::new(TEST_TASKS); - - let (tx, rx) = mpsc::channel(); - - for _ in 0..test_tasks { - let tx = tx.clone(); - let (b0, b1) = (b0.clone(), b1.clone()); - - hive.apply_store(Thunk::of(move || { - // Wait until the pool has been filled once. - b0.wait(); - // wait so the pool can be measured - b1.wait(); - assert!(tx.send(1).is_ok()); - })); - } - - b0.wait(); - assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - b1.wait(); - - assert_eq!(rx.iter().take(test_tasks).sum::(), test_tasks); - hive.join(); - - let atomic_num_tasks_active = hive.num_tasks().1; - assert!( - atomic_num_tasks_active == 0, - "atomic_num_tasks_active: {}", - atomic_num_tasks_active - ); - } - - #[test] - fn test_name() { - let name = "test"; - let hive = Builder::new() - .thread_name(name.to_owned()) - .num_threads(2) - .build_with_default::>(); - let (tx, rx) = mpsc::channel(); - - // initial thread should share the name "test" - for _ in 0..2 { - let tx = tx.clone(); - hive.apply_store(Thunk::of(move || { - let name = thread::current().name().unwrap().to_owned(); - tx.send(name).unwrap(); - })); - } - - // new spawn thread should share the name "test" too. 
- hive.grow(3).expect("error spawning threads"); - let tx_clone = tx.clone(); - hive.apply_store(Thunk::of(move || { - let name = thread::current().name().unwrap().to_owned(); - tx_clone.send(name).unwrap(); - })); - - for thread_name in rx.iter().take(3) { - assert_eq!(name, thread_name); - } - } - - #[test] - fn test_stack_size() { - let stack_size = 4_000_000; - - let hive = Builder::new() - .num_threads(1) - .thread_stack_size(stack_size) - .build_with_default::>(); - - let actual_stack_size = hive - .apply(Thunk::of(|| { - //println!("This thread has a 4 MB stack size!"); - stacker::remaining_stack().unwrap() - })) - .unwrap() as f64; - - // measured value should be within 1% of actual - assert!(actual_stack_size > (stack_size as f64 * 0.99)); - assert!(actual_stack_size < (stack_size as f64 * 1.01)); - } - - #[test] - fn test_debug() { - let hive = thunk_hive::<()>(4); - let debug = format!("{:?}", hive); - assert_eq!( - debug, - "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" - ); - - let hive = Builder::new() - .thread_name("hello") - .num_threads(4) - .build_with_default::>(); - let debug = format!("{:?}", hive); - assert_eq!( - debug, - "Hive { task_tx: Sender { .. }, shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" - ); - - let hive = thunk_hive(4); - hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); - thread::sleep(ONE_SEC); - let debug = format!("{:?}", hive); - assert_eq!( - debug, - "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" - ); - } - - #[test] - fn test_repeated_join() { - let hive = Builder::new() - .thread_name("repeated join test") - .num_threads(8) - .build_with_default::>(); - let test_count = Arc::new(AtomicUsize::new(0)); - - for _ in 0..42 { - let test_count = test_count.clone(); - hive.apply_store(Thunk::of(move || { - thread::sleep(SHORT_TASK); - test_count.fetch_add(1, Ordering::Release); - })); - } - - hive.join(); - assert_eq!(42, test_count.load(Ordering::Acquire)); - - for _ in 0..42 { - let test_count = test_count.clone(); - hive.apply_store(Thunk::of(move || { - thread::sleep(SHORT_TASK); - test_count.fetch_add(1, Ordering::Relaxed); - })); - } - hive.join(); - assert_eq!(84, test_count.load(Ordering::Relaxed)); - } - - #[test] - fn test_multi_join() { - // Toggle the following lines to debug the deadlock - // fn error(_s: String) { - // use ::std::io::Write; - // let stderr = ::std::io::stderr(); - // let mut stderr = stderr.lock(); - // stderr - // .write(&_s.as_bytes()) - // .expect("Failed to write to stderr"); - // } - - let hive0 = Builder::new() - .thread_name("multi join pool0") - .num_threads(4) - .build_with_default::>(); - let hive1 = Builder::new() - .thread_name("multi join pool1") - .num_threads(4) - .build_with_default::>(); - let (tx, rx) = crate::channel::channel(); - - for i in 0..8 { - let hive1_clone = hive1.clone(); - let hive0_clone = hive0.clone(); - let tx = tx.clone(); - hive0.apply_store(Thunk::of(move || { - hive1_clone.apply_store(Thunk::of(move || { - //error(format!("p1: {} -=- {:?}\n", i, hive0_clone)); - hive0_clone.join(); - // ensure that the main thread has a chance to execute - thread::sleep(Duration::from_millis(10)); - //error(format!("p1: send({})\n", i)); - tx.send(i).expect("send failed from hive1_clone to main"); - })); - //error(format!("p0: {}\n", i)); - })); - } - drop(tx); - - // no hive1 task should be 
completed yet, so the channel should be empty - let before_any_send = rx.try_recv_msg(); - assert!(matches!(before_any_send, Message::ChannelEmpty)); - //error(format!("{:?}\n{:?}\n", hive0, hive1)); - hive0.join(); - //error(format!("pool0.join() complete =-= {:?}", hive1)); - hive1.join(); - //error("pool1.join() complete\n".into()); - assert_eq!(rx.into_iter().sum::(), (0..8).sum()); - } - - #[test] - fn test_empty_hive() { - // Joining an empty hive must return imminently - let hive = thunk_hive::<()>(4); - hive.join(); - } - - #[test] - fn test_no_fun_or_joy() { - // What happens when you keep adding tasks after a join - - fn sleepy_function() { - thread::sleep(LONG_TASK); - } - - let hive = Builder::new() - .thread_name("no fun or joy") - .num_threads(8) - .build_with_default::>(); - - hive.apply_store(Thunk::of(sleepy_function)); - - let p_t = hive.clone(); - thread::spawn(move || { - (0..23) - .inspect(|_| { - p_t.apply_store(Thunk::of(sleepy_function)); - }) - .count(); - }); - - hive.join(); - } - - #[test] - fn test_map() { - let hive = Builder::new() - .num_threads(2) - .build_with_default::>(); - let outputs: Vec<_> = hive - .map((0..10u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((10 - i as u64) * 100)); - i - }) - })) - .map(Outcome::unwrap) - .collect(); - assert_eq!(outputs, (0..10).collect::>()) - } - - #[test] - fn test_map_unordered() { - let hive = Builder::new() - .num_threads(8) - .build_with_default::>(); - let outputs: Vec<_> = hive - .map_unordered((0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - })) - .map(Outcome::unwrap) - .collect(); - assert_eq!(outputs, (0..8).rev().collect::>()) - } - - #[test] - fn test_map_send() { - let hive = Builder::new() - .num_threads(8) - .build_with_default::>(); - let (tx, rx) = super::outcome_channel(); - let mut task_ids = hive.map_send( - (0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - }), - &tx, - ); - let (mut outcome_task_ids, values): (Vec, Vec) = rx - .iter() - .map(|outcome| match outcome { - Outcome::Success { value, task_id } => (task_id, value), - _ => panic!("unexpected error"), - }) - .unzip(); - assert_eq!(values, (0..8).rev().collect::>()); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - fn test_map_store() { - let mut hive = Builder::new() - .num_threads(8) - .build_with_default::>(); - let mut task_ids = hive.map_store((0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - })); - hive.join(); - for i in task_ids.iter() { - assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - } - let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - .clone() - .into_iter() - .map(|i| (i, hive.remove_success(i).unwrap())) - .collect(); - assert_eq!(values, (0..8).collect::>()); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - fn test_swarm() { - let hive = Builder::new() - .num_threads(2) - .build_with_default::>(); - let outputs: Vec<_> = hive - .swarm((0..10u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((10 - i as u64) * 100)); - i - }) - })) - .map(Outcome::unwrap) - .collect(); - assert_eq!(outputs, (0..10).collect::>()) - } - - #[test] - fn test_swarm_unordered() { - let hive = Builder::new() - .num_threads(8) - .build_with_default::>(); - let outputs: 
Vec<_> = hive - .swarm_unordered((0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - })) - .map(Outcome::unwrap) - .collect(); - assert_eq!(outputs, (0..8).rev().collect::>()) - } - - #[test] - fn test_swarm_send() { - let hive = Builder::new() - .num_threads(8) - .build_with_default::>(); - let (tx, rx) = super::outcome_channel(); - let mut task_ids = hive.swarm_send( - (0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - }), - &tx, - ); - let (mut outcome_task_ids, values): (Vec, Vec) = rx - .iter() - .map(|outcome| match outcome { - Outcome::Success { value, task_id } => (task_id, value), - _ => panic!("unexpected error"), - }) - .unzip(); - assert_eq!(values, (0..8).rev().collect::>()); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - fn test_swarm_store() { - let mut hive = Builder::new() - .num_threads(8) - .build_with_default::>(); - let mut task_ids = hive.swarm_store((0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - })); - hive.join(); - for i in task_ids.iter() { - assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - } - let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - .clone() - .into_iter() - .map(|i| (i, hive.remove_success(i).unwrap())) - .collect(); - assert_eq!(values, (0..8).collect::>()); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - fn test_scan() { - let hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i| i * i)); - let (outputs, state) = hive.scan(0..10, 0, |acc, i| { - *acc += i; - *acc - }); - let mut outputs = outputs.unwrap(); - outputs.sort(); - assert_eq!(outputs.len(), 10); - assert_eq!(state, 45); - assert_eq!( - outputs, - (0..10) - .scan(0, |acc, i| { - *acc += i; - Some(*acc) - }) - .map(|i| i * i) - .collect::>() - ); - } - - #[test] - fn test_scan_send() { - let hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i| i * i)); - let (tx, rx) = super::outcome_channel(); - let (mut task_ids, state) = hive.scan_send(0..10, &tx, 0, |acc, i| { - *acc += i; - *acc - }); - assert_eq!(task_ids.len(), 10); - assert_eq!(state, 45); - let (mut outcome_task_ids, mut values): (Vec, Vec) = rx - .iter() - .map(|outcome| match outcome { - Outcome::Success { value, task_id } => (task_id, value), - _ => panic!("unexpected error"), - }) - .unzip(); - values.sort(); - assert_eq!( - values, - (0..10) - .scan(0, |acc, i| { - *acc += i; - Some(*acc) - }) - .map(|i| i * i) - .collect::>() - ); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - fn test_try_scan_send() { - let hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i| i * i)); - let (tx, rx) = super::outcome_channel(); - let (results, state) = hive.try_scan_send(0..10, &tx, 0, |acc, i| { - *acc += i; - Ok::<_, String>(*acc) - }); - let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect(); - assert_eq!(task_ids.len(), 10); - assert_eq!(state, 45); - let (mut outcome_task_ids, mut values): (Vec, Vec) = rx - .iter() - .map(|outcome| match outcome { - Outcome::Success { value, task_id } => (task_id, value), - _ => panic!("unexpected error"), - }) - .unzip(); - values.sort(); - assert_eq!( - values, - (0..10) - .scan(0, |acc, i| { - *acc += i; - Some(*acc) - }) - .map(|i| i * i) - 
.collect::>() - ); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - #[should_panic] - fn test_try_scan_send_fail() { - let hive = Builder::new() - .num_threads(4) - .build_with(OnceCaller::of(|i: i32| Ok::<_, String>(i * i))); - let (tx, _) = super::outcome_channel(); - let _ = hive - .try_scan_send(0..10, &tx, 0, |_, _| Err("fail")) - .0 - .into_iter() - .map(Result::unwrap) - .collect::>(); - } - - #[test] - fn test_scan_store() { - let mut hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i| i * i)); - let (mut task_ids, state) = hive.scan_store(0..10, 0, |acc, i| { - *acc += i; - *acc - }); - assert_eq!(task_ids.len(), 10); - assert_eq!(state, 45); - hive.join(); - for i in task_ids.iter() { - assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - } - let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - .clone() - .into_iter() - .map(|i| (i, hive.remove_success(i).unwrap())) - .unzip(); - assert_eq!( - values, - (0..10) - .scan(0, |acc, i| { - *acc += i; - Some(*acc) - }) - .map(|i| i * i) - .collect::>() - ); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - fn test_try_scan_store() { - let mut hive = Builder::new() - .num_threads(4) - .build_with(Caller::of(|i| i * i)); - let (results, state) = hive.try_scan_store(0..10, 0, |acc, i| { - *acc += i; - Ok::(*acc) - }); - let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect(); - assert_eq!(task_ids.len(), 10); - assert_eq!(state, 45); - hive.join(); - for i in task_ids.iter() { - assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - } - let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - .clone() - .into_iter() - .map(|i| (i, hive.remove_success(i).unwrap())) - .unzip(); - assert_eq!( - values, - (0..10) - .scan(0, |acc, i| { - *acc += i; - Some(*acc) - }) - .map(|i| i * i) - .collect::>() - ); - task_ids.sort(); - outcome_task_ids.sort(); - assert_eq!(task_ids, outcome_task_ids); - } - - #[test] - #[should_panic] - fn test_try_scan_store_fail() { - let hive = Builder::new() - .num_threads(4) - .build_with(OnceCaller::of(|i: i32| Ok::(i * i))); - let _ = hive - .try_scan_store(0..10, 0, |_, _| Err("fail")) - .0 - .into_iter() - .map(Result::unwrap) - .collect::>(); - } - - #[test] - fn test_husk() { - let hive1 = Builder::new() - .num_threads(8) - .build_with_default::>(); - let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::of(move || i))); - hive1.join(); - let mut husk1 = hive1.try_into_husk().unwrap(); - for i in task_ids.iter() { - assert!(husk1.outcomes_deref().get(i).unwrap().is_success()); - assert!(matches!(husk1.get(*i), Some(Outcome::Success { .. 
}))); - } - - let builder = husk1.as_builder(); - let hive2 = builder - .num_threads(4) - .build_with_default::>(); - hive2.map_store((0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - })); - hive2.join(); - let mut husk2 = hive2.try_into_husk().unwrap(); - - let mut outputs1 = husk1 - .remove_all() - .into_iter() - .map(Outcome::unwrap) - .collect::>(); - outputs1.sort(); - let mut outputs2 = husk2 - .remove_all() - .into_iter() - .map(Outcome::unwrap) - .collect::>(); - outputs2.sort(); - assert_eq!(outputs1, outputs2); - - let hive3 = husk1.into_hive(); - hive3.map_store((0..8u8).map(|i| { - Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - i - }) - })); - hive3.join(); - let husk3 = hive3.try_into_husk().unwrap(); - let (_, outcomes3) = husk3.into_parts(); - let mut outputs3 = outcomes3 - .into_iter() - .map(Outcome::unwrap) - .collect::>(); - outputs3.sort(); - assert_eq!(outputs1, outputs3); - } - - #[test] - fn test_clone() { - let hive = Builder::new() - .thread_name("clone example") - .num_threads(2) - .build_with_default::>(); - - // This batch of tasks will occupy the pool for some time - for _ in 0..6 { - hive.apply_store(Thunk::of(|| { - thread::sleep(SHORT_TASK); - })); - } - - // The following tasks will be inserted into the pool in a random fashion - let t0 = { - let hive = hive.clone(); - thread::spawn(move || { - // wait for the first batch of tasks to finish - hive.join(); - - let (tx, rx) = mpsc::channel(); - for i in 0..42 { - let tx = tx.clone(); - hive.apply_store(Thunk::of(move || { - tx.send(i).expect("channel will be waiting"); - })); - } - drop(tx); - rx.iter().sum::() - }) - }; - let t1 = { - let pool = hive.clone(); - thread::spawn(move || { - // wait for the first batch of tasks to finish - pool.join(); - - let (tx, rx) = mpsc::channel(); - for i in 1..12 { - let tx = tx.clone(); - pool.apply_store(Thunk::of(move || { - tx.send(i).expect("channel will be waiting"); - })); - } - drop(tx); - rx.iter().product::() - }) - }; - - assert_eq!( - 861, - t0.join() - .expect("thread 0 will return after calculating additions",) - ); - assert_eq!( - 39916800, - t1.join() - .expect("thread 1 will return after calculating multiplications",) - ); - } - - type VoidThunkWorker = ThunkWorker<()>; - type VoidThunkWorkerHive = Hive>; - - #[test] - fn test_send() { - fn assert_send() {} - assert_send::(); + .build_with_default::<_, ChannelQueues>>() } - #[test] - fn test_cloned_eq() { - let a = thunk_hive::<()>(2); - assert_eq!(a, a.clone()); - } - - #[test] - /// When a thread joins on a pool, it blocks until all tasks have completed. If a second thread - /// adds tasks to the pool and then joins before all the tasks have completed, both threads - /// will wait for all tasks to complete. However, as soon as all tasks have completed, all - /// joining threads are notified, and the first one to wake will exit the join and increment - /// the phase of the condvar. Subsequent notified threads will then see that the phase has been - /// changed and will wake, even if new tasks have been added in the meantime. - /// - /// In this example, this means the waiting threads will exit the join in groups of four - /// because the waiter pool has four threads. 
- fn test_join_wavesurfer() { - let n_waves = 4; - let n_workers = 4; - let (tx, rx) = mpsc::channel(); - let builder = Builder::new() - .num_threads(n_workers) - .thread_name("join wavesurfer"); - let waiter_hive = builder.clone().build_with_default::>(); - let clock_hive = builder.build_with_default::>(); - - let barrier = Arc::new(Barrier::new(3)); - let wave_counter = Arc::new(AtomicUsize::new(0)); - let clock_thread = { - let barrier = barrier.clone(); - let wave_counter = wave_counter.clone(); - thread::spawn(move || { - barrier.wait(); - for wave_num in 0..n_waves { - let _ = wave_counter.swap(wave_num, Ordering::SeqCst); - thread::sleep(ONE_SEC); - } - }) - }; - - { - let barrier = barrier.clone(); - clock_hive.apply_store(Thunk::of(move || { - barrier.wait(); - // this sleep is for stabilisation on weaker platforms - thread::sleep(Duration::from_millis(100)); - })); - } - - // prepare three waves of tasks (0..=11) - for worker in 0..(3 * n_workers) { - let tx = tx.clone(); - let clock_hive = clock_hive.clone(); - let wave_counter = wave_counter.clone(); - waiter_hive.apply_store(Thunk::of(move || { - let wave_before = wave_counter.load(Ordering::SeqCst); - clock_hive.join(); - // submit tasks for the next wave - clock_hive.apply_store(Thunk::of(|| thread::sleep(ONE_SEC))); - let wave_after = wave_counter.load(Ordering::SeqCst); - tx.send((wave_before, wave_after, worker)).unwrap(); - })); - } - barrier.wait(); - - clock_hive.join(); - - drop(tx); - let mut hist = vec![0; n_waves]; - let mut data = vec![]; - for (before, after, worker) in rx.iter() { - let mut dur = after - before; - if dur >= n_waves - 1 { - dur = n_waves - 1; - } - hist[dur] += 1; - data.push((before, after, worker)); - } - - println!("Histogram of wave duration:"); - for (i, n) in hist.iter().enumerate() { - println!( - "\t{}: {} {}", - i, - n, - &*(0..*n).fold("".to_owned(), |s, _| s + "*") - ); - } - - for (wave_before, wave_after, worker) in data.iter() { - if *worker < n_workers { - assert_eq!(wave_before, wave_after); - } else { - assert!(wave_before < wave_after); - } - } - clock_thread.join().unwrap(); - } - - // cargo-llvm-cov doesn't yet support doctests in stable, so we need to duplicate them in - // unit tests to get coverage - - #[test] - fn doctest_lib_2() { - // create a hive to process `Thunk`s - no-argument closures with the same return type (`i32`) - let hive = Builder::new() - .num_threads(4) - .thread_name("thunk_hive") - .build_with_default::>(); - - // return results to your own channel... - let (tx, rx) = crate::hive::outcome_channel(); - let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::of(move || i * i)), &tx); - let outputs: Vec<_> = rx.select_unordered_outputs(task_ids).collect(); - assert_eq!(285, outputs.into_iter().sum()); - - // return results as an iterator... 
- let outputs2: Vec<_> = hive - .swarm((0..10).map(|i: i32| Thunk::of(move || i * -i))) - .into_outputs() - .collect(); - assert_eq!(-285, outputs2.into_iter().sum()); - } - - #[test] - fn doctest_lib_3() { - #[derive(Debug)] - struct CatWorker { - stdin: ChildStdin, - stdout: BufReader, - } - - impl CatWorker { - fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self { - Self { - stdin, - stdout: BufReader::new(stdout), - } - } - - fn write_char(&mut self, c: u8) -> io::Result { - self.stdin.write_all(&[c])?; - self.stdin.write_all(b"\n")?; - self.stdin.flush()?; - let mut s = String::new(); - self.stdout.read_line(&mut s)?; - s.pop(); // exclude newline - Ok(s) - } - } - - impl Worker for CatWorker { - type Input = u8; - type Output = String; - type Error = io::Error; - - fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { - self.write_char(input).map_err(|error| ApplyError::Fatal { - input: Some(input), - error, - }) - } - } - - #[derive(Default)] - struct CatQueen { - children: Vec, - } - - impl CatQueen { - fn wait_for_all(&mut self) -> Vec> { - self.children - .drain(..) - .map(|mut child| child.wait()) - .collect() - } - } - - impl Queen for CatQueen { - type Kind = CatWorker; - - fn create(&mut self) -> Self::Kind { - let mut child = Command::new("cat") - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::inherit()) - .spawn() - .unwrap(); - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - self.children.push(child); - CatWorker::new(stdin, stdout) - } - } - - impl Drop for CatQueen { - fn drop(&mut self) { - self.wait_for_all() - .into_iter() - .for_each(|result| match result { - Ok(status) if status.success() => (), - Ok(status) => eprintln!("Child process failed: {}", status), - Err(e) => eprintln!("Error waiting for child process: {}", e), - }) - } - } - - // build the Hive - let hive = Builder::new().num_threads(4).build_default::(); - - // prepare inputs - let inputs: Vec = (0..8).map(|i| 97 + i).collect(); - - // execute tasks and collect outputs - let output = hive - .swarm(inputs) - .into_outputs() - .fold(String::new(), |mut a, b| { - a.push_str(&b); - a - }) - .into_bytes(); - - // verify the output - note that `swarm` ensures the outputs are in the same order - // as the inputs - assert_eq!(output, b"abcdefgh"); - - // shutdown the hive, use the Queen to wait on child processes, and report errors - let (mut queen, _) = hive.try_into_husk().unwrap().into_parts(); - let (wait_ok, wait_err): (Vec<_>, Vec<_>) = - queen.wait_for_all().into_iter().partition(Result::is_ok); - if !wait_err.is_empty() { - panic!( - "Error(s) occurred while waiting for child processes: {:?}", - wait_err - ); - } - let exec_err_codes: Vec<_> = wait_ok - .into_iter() - .map(Result::unwrap) - .filter(|status| !status.success()) - .filter_map(|status| status.code()) - .collect(); - if !exec_err_codes.is_empty() { - panic!( - "Child process(es) failed with exit codes: {:?}", - exec_err_codes - ); - } - } -} - -#[cfg(all(test, feature = "affinity"))] -mod affinity_tests { - use crate::bee::stock::{Thunk, ThunkWorker}; - use crate::hive::Builder; - - #[test] - fn test_affinity() { - let hive = Builder::new() - .thread_name("affinity example") - .num_threads(2) - .core_affinity(0..2) - .build_with_default::>(); - - hive.map_store((0..10).map(move |i| { - Thunk::of(move || { - if let Some(affinity) = core_affinity::get_core_ids() { - eprintln!("task {} on thread with affinity {:?}", i, affinity); - } - }) - })); - } - - #[test] - fn 
test_use_all_cores() { - let hive = Builder::new() - .thread_name("affinity example") - .with_thread_per_core() - .with_default_core_affinity() - .build_with_default::>(); - - hive.map_store((0..num_cpus::get()).map(move |i| { - Thunk::of(move || { - if let Some(affinity) = core_affinity::get_core_ids() { - eprintln!("task {} on thread with affinity {:?}", i, affinity); - } - }) - })); - } -} - -#[cfg(all(test, feature = "batching"))] -mod batching_tests { - use crate::barrier::IndexedBarrier; - use crate::bee::stock::{Thunk, ThunkWorker}; - use crate::bee::DefaultQueen; - use crate::hive::{Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender}; - use std::collections::HashMap; - use std::thread::{self, ThreadId}; - use std::time::Duration; - - fn launch_tasks( - hive: &Hive, DefaultQueen>>, - num_threads: usize, - num_tasks_per_thread: usize, - barrier: &IndexedBarrier, - tx: &OutcomeSender>, - ) -> Vec { - let total_tasks = num_threads * num_tasks_per_thread; - // send the first `num_threads` tasks widely spaced, so each worker thread only gets one - let init_task_ids: Vec<_> = (0..num_threads) - .map(|_| { - let barrier = barrier.clone(); - let task_id = hive.apply_send( - Thunk::of(move || { - barrier.wait(); - thread::sleep(Duration::from_millis(100)); - thread::current().id() - }), - tx, - ); - thread::sleep(Duration::from_millis(100)); - task_id - }) - .collect(); - // send the rest all at once - let rest_task_ids = hive.map_send( - (num_threads..total_tasks).map(|_| { - Thunk::of(move || { - thread::sleep(Duration::from_millis(1)); - thread::current().id() - }) - }), - tx, - ); - init_task_ids.into_iter().chain(rest_task_ids).collect() - } - - fn count_thread_ids( - rx: OutcomeReceiver>, - task_ids: Vec, - ) -> HashMap { - rx.select_unordered_outputs(task_ids) - .fold(HashMap::new(), |mut counter, id| { - *counter.entry(id).or_insert(0) += 1; - counter - }) - } - - fn run_test( - hive: &Hive, DefaultQueen>>, - num_threads: usize, - batch_size: usize, - ) { - let tasks_per_thread = batch_size + 2; - let (tx, rx) = crate::hive::outcome_channel(); - // each worker gets one initial task, then takes `batch_size` tasks for its queue + 1 to work on immediately, - // meaning there should be `batch_size + 2` tasks associated with each thread ID - let barrier = IndexedBarrier::new(num_threads); - let task_ids = launch_tasks(hive, num_threads, tasks_per_thread, &barrier, &tx); - // start the first tasks - barrier.wait(); - // wait for all tasks to complete - hive.join(); - let thread_counts = count_thread_ids(rx, task_ids); - assert_eq!(thread_counts.len(), num_threads); - assert!(thread_counts - .values() - .all(|&count| count == tasks_per_thread)); - } - - #[test] - fn test_batching() { - const NUM_THREADS: usize = 4; - const BATCH_SIZE: usize = 24; - let hive = Builder::new() - .num_threads(NUM_THREADS) - .batch_size(BATCH_SIZE) - .build_with_default::>(); - run_test(&hive, NUM_THREADS, BATCH_SIZE); - } - - #[test] - fn test_set_batch_size() { - const NUM_THREADS: usize = 4; - const BATCH_SIZE_0: usize = 10; - const BATCH_SIZE_1: usize = 20; - const BATCH_SIZE_2: usize = 50; - let hive = Builder::new() - .num_threads(NUM_THREADS) - .batch_size(BATCH_SIZE_0) - .build_with_default::>(); - run_test(&hive, NUM_THREADS, BATCH_SIZE_0); - // increase batch size - hive.set_worker_batch_size(BATCH_SIZE_2); - run_test(&hive, NUM_THREADS, BATCH_SIZE_2); - // decrease batch size - hive.set_worker_batch_size(BATCH_SIZE_1); - run_test(&hive, NUM_THREADS, BATCH_SIZE_1); - } - - #[test] - fn 
test_shrink_batch_size() { - const NUM_THREADS: usize = 4; - const NUM_TASKS_PER_THREAD: usize = 125; - const BATCH_SIZE_0: usize = 100; - const BATCH_SIZE_1: usize = 10; - let hive = Builder::new() - .num_threads(NUM_THREADS) - .batch_size(BATCH_SIZE_0) - .build_with_default::>(); - let (tx, rx) = crate::hive::outcome_channel(); - let barrier = IndexedBarrier::new(NUM_THREADS); - let task_ids = launch_tasks(&hive, NUM_THREADS, NUM_TASKS_PER_THREAD, &barrier, &tx); - let total_tasks = NUM_THREADS * NUM_TASKS_PER_THREAD; - assert_eq!(task_ids.len(), total_tasks); - barrier.wait(); - hive.set_worker_batch_size(BATCH_SIZE_1); - // The number of tasks completed by each thread could be variable, so we want to ensure - // that a) each processed at least `BATCH_SIZE_0` tasks, and b) there are a total of - // `NUM_TASKS` outputs with no errors - hive.join(); - let thread_counts = count_thread_ids(rx, task_ids); - assert!(thread_counts.values().all(|count| *count > BATCH_SIZE_0)); - assert_eq!(thread_counts.values().sum::(), total_tasks); - } -} - -#[cfg(all(test, feature = "retry"))] -mod retry_tests { - use crate::bee::stock::RetryCaller; - use crate::bee::{ApplyError, Context}; - use crate::hive::{Builder, Outcome, OutcomeIteratorExt}; - use std::time::{Duration, SystemTime}; - - fn echo_time(i: usize, ctx: &Context) -> Result> { - let attempt = ctx.attempt(); - if attempt == 3 { - Ok("Success".into()) - } else { - // the delay between each message should be exponential - eprintln!("Task {} attempt {}: {:?}", i, attempt, SystemTime::now()); - Err(ApplyError::Retryable { - input: i, - error: "Retryable".into(), - }) - } - } - - #[test] - fn test_retries() { - let hive = Builder::new() - .with_thread_per_core() - .max_retries(3) - .retry_factor(Duration::from_secs(1)) - .build_with(RetryCaller::of(echo_time)); - - let v: Result, _> = hive.swarm(0..10).into_results().collect(); - assert_eq!(v.unwrap().len(), 10); - } - - #[test] - fn test_retries_fail() { - fn sometimes_fail( - i: usize, - _: &Context, - ) -> Result> { - match i % 3 { - 0 => Ok("Success".into()), - 1 => Err(ApplyError::Retryable { - input: i, - error: "Retryable".into(), - }), - 2 => Err(ApplyError::Fatal { - input: Some(i), - error: "Fatal".into(), - }), - _ => unreachable!(), - } - } - - let hive = Builder::new() - .with_thread_per_core() - .max_retries(3) - .build_with(RetryCaller::of(sometimes_fail)); - - let (success, retry_failed, not_retried) = hive.swarm(0..10).fold( - (0, 0, 0), - |(success, retry_failed, not_retried), outcome| match outcome { - Outcome::Success { .. } => (success + 1, retry_failed, not_retried), - Outcome::MaxRetriesAttempted { .. } => (success, retry_failed + 1, not_retried), - Outcome::Failure { .. 
} => (success, retry_failed, not_retried + 1), - _ => unreachable!(), - }, - ); - - assert_eq!(success, 4); - assert_eq!(retry_failed, 3); - assert_eq!(not_retried, 3); - } - - #[test] - fn test_disable_retries() { - let hive = Builder::new() - .with_thread_per_core() - .with_no_retries() - .build_with(RetryCaller::of(echo_time)); - let v: Result, _> = hive.swarm(0..10).into_results().collect(); - assert!(v.is_err()); - } + // #[test] + // fn test_works() { + // let hive = thunk_hive(TEST_TASKS); + // let (tx, rx) = mpsc::channel(); + // assert_eq!(hive.max_workers(), TEST_TASKS); + // assert_eq!(hive.alive_workers(), TEST_TASKS); + // assert!(!hive.has_dead_workers()); + // for _ in 0..TEST_TASKS { + // let tx = tx.clone(); + // hive.apply_store(Thunk::of(move || { + // tx.send(1).unwrap(); + // })); + // } + // assert_eq!(rx.iter().take(TEST_TASKS).sum::(), TEST_TASKS); + // } + + // #[test] + // fn test_grow_from_zero() { + // let hive = thunk_hive::(0); + // // check that with 0 threads no tasks are scheduled + // let (tx, rx) = super::outcome_channel(); + // let _ = hive.apply_send(Thunk::of(|| 0), &tx); + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks().0, 1); + // assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); + // hive.grow(1).expect("error spawning threads"); + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks().0, 0); + // assert!(matches!( + // rx.try_recv_msg(), + // Message::Received(Outcome::Success { value: 0, .. }) + // )); + // } + + // #[test] + // fn test_grow() { + // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); + // // queue some long-running tasks + // for _ in 0..TEST_TASKS { + // hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + // } + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); + // // increase the number of threads + // let new_threads = 4; + // let total_threads = new_threads + TEST_TASKS; + // hive.grow(new_threads).expect("error spawning threads"); + // // queue some more long-running tasks + // for _ in 0..new_threads { + // hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + // } + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks().1, total_threads as u64); + // let husk = hive.try_into_husk().unwrap(); + // assert_eq!(husk.iter_successes().count(), total_threads); + // } + + // #[test] + // fn test_suspend() { + // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); + // // queue some long-running tasks + // let total_tasks = 2 * TEST_TASKS; + // for _ in 0..total_tasks { + // hive.apply_store(Thunk::of(|| thread::sleep(SHORT_TASK))); + // } + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, TEST_TASKS as u64)); + // hive.suspend(); + // // active tasks should finish but no more tasks should be started + // thread::sleep(SHORT_TASK); + // assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, 0)); + // assert_eq!(hive.num_successes(), TEST_TASKS); + // hive.resume(); + // // new tasks should start + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64)); + // thread::sleep(SHORT_TASK); + // // all tasks should be completed + // assert_eq!(hive.num_tasks(), (0, 0)); + // assert_eq!(hive.num_successes(), total_tasks); + // } + + // #[derive(Debug, Default)] + // struct MyRefWorker; + + // impl RefWorker for MyRefWorker { + // type Input = u8; + // type Output = u8; + // type Error = (); + + // fn apply_ref( + // &mut self, + // 
input: &Self::Input, + // ctx: &Context, + // ) -> RefWorkerResult { + // for _ in 0..3 { + // thread::sleep(Duration::from_secs(1)); + // if ctx.is_cancelled() { + // return Err(ApplyRefError::Cancelled); + // } + // } + // Ok(*input) + // } + // } + + // #[test] + // fn test_suspend_with_cancelled_tasks() { + // let hive = Builder::new() + // .num_threads(TEST_TASKS) + // .build_with_default::, Local<_>>(); + // hive.swarm_store(0..TEST_TASKS as u8); + // hive.suspend(); + // // wait for tasks to be cancelled + // thread::sleep(Duration::from_secs(2)); + // hive.resume_store(); + // thread::sleep(Duration::from_secs(1)); + // // unprocessed tasks should be requeued + // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); + // thread::sleep(Duration::from_secs(3)); + // assert_eq!(hive.num_successes(), TEST_TASKS); + // } + + // #[test] + // fn test_num_tasks_active() { + // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); + // for _ in 0..2 * TEST_TASKS { + // hive.apply_store(Thunk::of(|| loop { + // thread::sleep(LONG_TASK) + // })); + // } + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); + // let num_threads = hive.max_workers(); + // assert_eq!(num_threads, TEST_TASKS); + // } + + // #[test] + // fn test_all_threads() { + // let hive = Builder::new() + // .with_thread_per_core() + // .build_with_default::, Global<_>, Local<_>>(); + // let num_threads = num_cpus::get(); + // for _ in 0..num_threads { + // hive.apply_store(Thunk::of(|| loop { + // thread::sleep(LONG_TASK) + // })); + // } + // thread::sleep(ONE_SEC); + // assert_eq!(hive.num_tasks().1, num_threads as u64); + // let max_workers = hive.max_workers(); + // assert_eq!(max_workers, num_threads); + // } + + // #[test] + // fn test_panic() { + // let hive = thunk_hive(TEST_TASKS); + // let (tx, _) = super::outcome_channel(); + // // Panic all the existing threads. + // for _ in 0..TEST_TASKS { + // hive.apply_send(Thunk::of(|| panic!("intentional panic")), &tx); + // } + // hive.join(); + // // Ensure that all of the task panics were counted + // assert_eq!(hive.num_panics(), TEST_TASKS); + // let husk = hive.try_into_husk().unwrap(); + // assert_eq!(husk.num_panics(), TEST_TASKS); + // } + + // #[test] + // fn test_catch_panic() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(TEST_TASKS) + // .build_with(RefCaller::of(|_: &u8| -> Result { + // panic!("intentional panic") + // })); + // let (tx, rx) = super::outcome_channel(); + // // Panic all the existing threads. + // for i in 0..TEST_TASKS { + // hive.apply_send(i as u8, &tx); + // } + // hive.join(); + // // Ensure that none of the threads have panicked + // assert_eq!(hive.num_panics(), 0); + // // Check that all the results are Outcome::Panic + // for outcome in rx.into_iter().take(TEST_TASKS) { + // assert!(matches!(outcome, Outcome::Panic { .. 
})); + // } + // } + + // #[test] + // fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() { + // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); + // let waiter = Arc::new(Barrier::new(TEST_TASKS + 1)); + // let waiter_count = Arc::new(AtomicUsize::new(0)); + + // // panic all the existing threads in a bit + // for _ in 0..TEST_TASKS { + // let waiter = waiter.clone(); + // let waiter_count = waiter_count.clone(); + // hive.apply_store(Thunk::of(move || { + // waiter_count.fetch_add(1, Ordering::SeqCst); + // waiter.wait(); + // panic!("intentional panic"); + // })); + // } + + // // queued tasks will not be processed after the hive is dropped, so we need to wait to make + // // sure that all tasks have started and are waiting on the barrier + // // TODO: find a Barrier implementation with try_wait() semantics + // thread::sleep(Duration::from_secs(1)); + // assert_eq!(waiter_count.load(Ordering::SeqCst), TEST_TASKS); + + // drop(hive); + + // // unblock the tasks and allow them to panic + // waiter.wait(); + // } + + // #[test] + // fn test_massive_task_creation() { + // let test_tasks = 4_200_000; + + // let hive = thunk_hive(TEST_TASKS); + // let b0 = IndexedBarrier::new(TEST_TASKS); + // let b1 = IndexedBarrier::new(TEST_TASKS); + + // let (tx, rx) = mpsc::channel(); + + // for _ in 0..test_tasks { + // let tx = tx.clone(); + // let (b0, b1) = (b0.clone(), b1.clone()); + + // hive.apply_store(Thunk::of(move || { + // // Wait until the pool has been filled once. + // b0.wait(); + // // wait so the pool can be measured + // b1.wait(); + // assert!(tx.send(1).is_ok()); + // })); + // } + + // b0.wait(); + // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); + // b1.wait(); + + // assert_eq!(rx.iter().take(test_tasks).sum::(), test_tasks); + // hive.join(); + + // let atomic_num_tasks_active = hive.num_tasks().1; + // assert!( + // atomic_num_tasks_active == 0, + // "atomic_num_tasks_active: {}", + // atomic_num_tasks_active + // ); + // } + + // #[test] + // fn test_name() { + // let name = "test"; + // let hive = Builder::new() + // .thread_name(name.to_owned()) + // .num_threads(2) + // .build_with_default::, Global<_>, Local<_>>(); + // let (tx, rx) = mpsc::channel(); + + // // initial thread should share the name "test" + // for _ in 0..2 { + // let tx = tx.clone(); + // hive.apply_store(Thunk::of(move || { + // let name = thread::current().name().unwrap().to_owned(); + // tx.send(name).unwrap(); + // })); + // } + + // // new spawn thread should share the name "test" too. 
+ // hive.grow(3).expect("error spawning threads"); + // let tx_clone = tx.clone(); + // hive.apply_store(Thunk::of(move || { + // let name = thread::current().name().unwrap().to_owned(); + // tx_clone.send(name).unwrap(); + // })); + + // for thread_name in rx.iter().take(3) { + // assert_eq!(name, thread_name); + // } + // } + + // #[test] + // fn test_stack_size() { + // let stack_size = 4_000_000; + + // let hive = Builder::new() + // .num_threads(1) + // .thread_stack_size(stack_size) + // .build_with_default::, Global<_>, Local<_>>(); + + // let actual_stack_size = hive + // .apply(Thunk::of(|| { + // //println!("This thread has a 4 MB stack size!"); + // stacker::remaining_stack().unwrap() + // })) + // .unwrap() as f64; + + // // measured value should be within 1% of actual + // assert!(actual_stack_size > (stack_size as f64 * 0.99)); + // assert!(actual_stack_size < (stack_size as f64 * 1.01)); + // } + + // #[test] + // fn test_debug() { + // let hive = thunk_hive::<()>(4); + // let debug = format!("{:?}", hive); + // assert_eq!( + // debug, + // "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + // ); + + // let hive = Builder::new() + // .thread_name("hello") + // .num_threads(4) + // .build_with_default::, Global<_>, Local<_>>(); + // let debug = format!("{:?}", hive); + // assert_eq!( + // debug, + // "Hive { task_tx: Sender { .. }, shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + // ); + + // let hive = thunk_hive(4); + // hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + // thread::sleep(ONE_SEC); + // let debug = format!("{:?}", hive); + // assert_eq!( + // debug, + // "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" + // ); + // } + + // #[test] + // fn test_repeated_join() { + // let hive = Builder::new() + // .thread_name("repeated join test") + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let test_count = Arc::new(AtomicUsize::new(0)); + + // for _ in 0..42 { + // let test_count = test_count.clone(); + // hive.apply_store(Thunk::of(move || { + // thread::sleep(SHORT_TASK); + // test_count.fetch_add(1, Ordering::Release); + // })); + // } + + // hive.join(); + // assert_eq!(42, test_count.load(Ordering::Acquire)); + + // for _ in 0..42 { + // let test_count = test_count.clone(); + // hive.apply_store(Thunk::of(move || { + // thread::sleep(SHORT_TASK); + // test_count.fetch_add(1, Ordering::Relaxed); + // })); + // } + // hive.join(); + // assert_eq!(84, test_count.load(Ordering::Relaxed)); + // } + + // #[test] + // fn test_multi_join() { + // // Toggle the following lines to debug the deadlock + // // fn error(_s: String) { + // // use ::std::io::Write; + // // let stderr = ::std::io::stderr(); + // // let mut stderr = stderr.lock(); + // // stderr + // // .write(&_s.as_bytes()) + // // .expect("Failed to write to stderr"); + // // } + + // let hive0 = Builder::new() + // .thread_name("multi join pool0") + // .num_threads(4) + // .build_with_default::, Global<_>, Local<_>>(); + // let hive1 = Builder::new() + // .thread_name("multi join pool1") + // .num_threads(4) + // .build_with_default::, Global<_>, Local<_>>(); + // let (tx, rx) = crate::channel::channel(); + + // for i in 0..8 { + // let hive1_clone = hive1.clone(); + // let hive0_clone = hive0.clone(); + // let tx = tx.clone(); + // hive0.apply_store(Thunk::of(move || { + // 
hive1_clone.apply_store(Thunk::of(move || { + // //error(format!("p1: {} -=- {:?}\n", i, hive0_clone)); + // hive0_clone.join(); + // // ensure that the main thread has a chance to execute + // thread::sleep(Duration::from_millis(10)); + // //error(format!("p1: send({})\n", i)); + // tx.send(i).expect("send failed from hive1_clone to main"); + // })); + // //error(format!("p0: {}\n", i)); + // })); + // } + // drop(tx); + + // // no hive1 task should be completed yet, so the channel should be empty + // let before_any_send = rx.try_recv_msg(); + // assert!(matches!(before_any_send, Message::ChannelEmpty)); + // //error(format!("{:?}\n{:?}\n", hive0, hive1)); + // hive0.join(); + // //error(format!("pool0.join() complete =-= {:?}", hive1)); + // hive1.join(); + // //error("pool1.join() complete\n".into()); + // assert_eq!(rx.into_iter().sum::(), (0..8).sum()); + // } + + // #[test] + // fn test_empty_hive() { + // // Joining an empty hive must return immediately + // let hive = thunk_hive::<()>(4); + // hive.join(); + // } + + // #[test] + // fn test_no_fun_or_joy() { + // // What happens when you keep adding tasks after a join + + // fn sleepy_function() { + // thread::sleep(LONG_TASK); + // } + + // let hive = Builder::new() + // .thread_name("no fun or joy") + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + + // hive.apply_store(Thunk::of(sleepy_function)); + + // let p_t = hive.clone(); + // thread::spawn(move || { + // (0..23) + // .inspect(|_| { + // p_t.apply_store(Thunk::of(sleepy_function)); + // }) + // .count(); + // }); + + // hive.join(); + // } + + // #[test] + // fn test_map() { + // let hive = Builder::new() + // .num_threads(2) + // .build_with_default::, Global<_>, Local<_>>(); + // let outputs: Vec<_> = hive + // .map((0..10u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + // i + // }) + // })) + // .map(Outcome::unwrap) + // .collect(); + // assert_eq!(outputs, (0..10).collect::>()) + // } + + // #[test] + // fn test_map_unordered() { + // let hive = Builder::new() + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let outputs: Vec<_> = hive + // .map_unordered((0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // })) + // .map(Outcome::unwrap) + // .collect(); + // assert_eq!(outputs, (0..8).rev().collect::>()) + // } + + // #[test] + // fn test_map_send() { + // let hive = Builder::new() + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let (tx, rx) = super::outcome_channel(); + // let mut task_ids = hive.map_send( + // (0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // }), + // &tx, + // ); + // let (mut outcome_task_ids, values): (Vec, Vec) = rx + // .iter() + // .map(|outcome| match outcome { + // Outcome::Success { value, task_id } => (task_id, value), + // _ => panic!("unexpected error"), + // }) + // .unzip(); + // assert_eq!(values, (0..8).rev().collect::>()); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // fn test_map_store() { + // let mut hive = Builder::new() + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let mut task_ids = hive.map_store((0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // })); 
+ // hive.join(); + // for i in task_ids.iter() { + // assert!(hive.outcomes_deref().get(i).unwrap().is_success()); + // } + // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids + // .clone() + // .into_iter() + // .map(|i| (i, hive.remove_success(i).unwrap())) + // .collect(); + // assert_eq!(values, (0..8).collect::>()); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // fn test_swarm() { + // let hive = Builder::new() + // .num_threads(2) + // .build_with_default::, Global<_>, Local<_>>(); + // let outputs: Vec<_> = hive + // .swarm((0..10u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + // i + // }) + // })) + // .map(Outcome::unwrap) + // .collect(); + // assert_eq!(outputs, (0..10).collect::>()) + // } + + // #[test] + // fn test_swarm_unordered() { + // let hive = Builder::new() + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let outputs: Vec<_> = hive + // .swarm_unordered((0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // })) + // .map(Outcome::unwrap) + // .collect(); + // assert_eq!(outputs, (0..8).rev().collect::>()) + // } + + // #[test] + // fn test_swarm_send() { + // let hive = Builder::new() + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let (tx, rx) = super::outcome_channel(); + // let mut task_ids = hive.swarm_send( + // (0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // }), + // &tx, + // ); + // let (mut outcome_task_ids, values): (Vec, Vec) = rx + // .iter() + // .map(|outcome| match outcome { + // Outcome::Success { value, task_id } => (task_id, value), + // _ => panic!("unexpected error"), + // }) + // .unzip(); + // assert_eq!(values, (0..8).rev().collect::>()); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // fn test_swarm_store() { + // let mut hive = Builder::new() + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let mut task_ids = hive.swarm_store((0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // })); + // hive.join(); + // for i in task_ids.iter() { + // assert!(hive.outcomes_deref().get(i).unwrap().is_success()); + // } + // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids + // .clone() + // .into_iter() + // .map(|i| (i, hive.remove_success(i).unwrap())) + // .collect(); + // assert_eq!(values, (0..8).collect::>()); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // fn test_scan() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(4) + // .build_with(Caller::of(|i| i * i)); + // let (outputs, state) = hive.scan(0..10, 0, |acc, i| { + // *acc += i; + // *acc + // }); + // let mut outputs = outputs.unwrap(); + // outputs.sort(); + // assert_eq!(outputs.len(), 10); + // assert_eq!(state, 45); + // assert_eq!( + // outputs, + // (0..10) + // .scan(0, |acc, i| { + // *acc += i; + // Some(*acc) + // }) + // .map(|i| i * i) + // .collect::>() + // ); + // } + + // #[test] + // fn test_scan_send() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(4) + // .build_with(Caller::of(|i| i * i)); + 
// let (tx, rx) = super::outcome_channel(); + // let (mut task_ids, state) = hive.scan_send(0..10, &tx, 0, |acc, i| { + // *acc += i; + // *acc + // }); + // assert_eq!(task_ids.len(), 10); + // assert_eq!(state, 45); + // let (mut outcome_task_ids, mut values): (Vec, Vec) = rx + // .iter() + // .map(|outcome| match outcome { + // Outcome::Success { value, task_id } => (task_id, value), + // _ => panic!("unexpected error"), + // }) + // .unzip(); + // values.sort(); + // assert_eq!( + // values, + // (0..10) + // .scan(0, |acc, i| { + // *acc += i; + // Some(*acc) + // }) + // .map(|i| i * i) + // .collect::>() + // ); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // fn test_try_scan_send() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(4) + // .build_with(Caller::of(|i| i * i)); + // let (tx, rx) = super::outcome_channel(); + // let (results, state) = hive.try_scan_send(0..10, &tx, 0, |acc, i| { + // *acc += i; + // Ok::<_, String>(*acc) + // }); + // let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect(); + // assert_eq!(task_ids.len(), 10); + // assert_eq!(state, 45); + // let (mut outcome_task_ids, mut values): (Vec, Vec) = rx + // .iter() + // .map(|outcome| match outcome { + // Outcome::Success { value, task_id } => (task_id, value), + // _ => panic!("unexpected error"), + // }) + // .unzip(); + // values.sort(); + // assert_eq!( + // values, + // (0..10) + // .scan(0, |acc, i| { + // *acc += i; + // Some(*acc) + // }) + // .map(|i| i * i) + // .collect::>() + // ); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // #[should_panic] + // fn test_try_scan_send_fail() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(4) + // .build_with(OnceCaller::of(|i: i32| Ok::<_, String>(i * i))); + // let (tx, _) = super::outcome_channel(); + // let _ = hive + // .try_scan_send(0..10, &tx, 0, |_, _| Err("fail")) + // .0 + // .into_iter() + // .map(Result::unwrap) + // .collect::>(); + // } + + // #[test] + // fn test_scan_store() { + // let mut hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(4) + // .build_with(Caller::of(|i| i * i)); + // let (mut task_ids, state) = hive.scan_store(0..10, 0, |acc, i| { + // *acc += i; + // *acc + // }); + // assert_eq!(task_ids.len(), 10); + // assert_eq!(state, 45); + // hive.join(); + // for i in task_ids.iter() { + // assert!(hive.outcomes_deref().get(i).unwrap().is_success()); + // } + // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids + // .clone() + // .into_iter() + // .map(|i| (i, hive.remove_success(i).unwrap())) + // .unzip(); + // assert_eq!( + // values, + // (0..10) + // .scan(0, |acc, i| { + // *acc += i; + // Some(*acc) + // }) + // .map(|i| i * i) + // .collect::>() + // ); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // fn test_try_scan_store() { + // let mut hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(4) + // .build_with(Caller::of(|i| i * i)); + // let (results, state) = hive.try_scan_store(0..10, 0, |acc, i| { + // *acc += i; + // Ok::(*acc) + // }); + // let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect(); + // assert_eq!(task_ids.len(), 10); + // assert_eq!(state, 45); + // hive.join(); + // for i in task_ids.iter() { + // 
assert!(hive.outcomes_deref().get(i).unwrap().is_success()); + // } + // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids + // .clone() + // .into_iter() + // .map(|i| (i, hive.remove_success(i).unwrap())) + // .unzip(); + // assert_eq!( + // values, + // (0..10) + // .scan(0, |acc, i| { + // *acc += i; + // Some(*acc) + // }) + // .map(|i| i * i) + // .collect::>() + // ); + // task_ids.sort(); + // outcome_task_ids.sort(); + // assert_eq!(task_ids, outcome_task_ids); + // } + + // #[test] + // #[should_panic] + // fn test_try_scan_store_fail() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .num_threads(4) + // .build_with(OnceCaller::of(|i: i32| Ok::(i * i))); + // let _ = hive + // .try_scan_store(0..10, 0, |_, _| Err("fail")) + // .0 + // .into_iter() + // .map(Result::unwrap) + // .collect::>(); + // } + + // #[test] + // fn test_husk() { + // let hive1 = Builder::new() + // .num_threads(8) + // .build_with_default::, Global<_>, Local<_>>(); + // let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::of(move || i))); + // hive1.join(); + // let mut husk1 = hive1.try_into_husk().unwrap(); + // for i in task_ids.iter() { + // assert!(husk1.outcomes_deref().get(i).unwrap().is_success()); + // assert!(matches!(husk1.get(*i), Some(Outcome::Success { .. }))); + // } + + // let builder = husk1.as_builder(); + // let hive2 = builder + // .num_threads(4) + // .build_with_default::, Global<_>, Local<_>>(); + // hive2.map_store((0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // })); + // hive2.join(); + // let mut husk2 = hive2.try_into_husk().unwrap(); + + // let mut outputs1 = husk1 + // .remove_all() + // .into_iter() + // .map(Outcome::unwrap) + // .collect::>(); + // outputs1.sort(); + // let mut outputs2 = husk2 + // .remove_all() + // .into_iter() + // .map(Outcome::unwrap) + // .collect::>(); + // outputs2.sort(); + // assert_eq!(outputs1, outputs2); + + // let hive3 = husk1.into_hive::, Local<_>>(); + // hive3.map_store((0..8u8).map(|i| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + // i + // }) + // })); + // hive3.join(); + // let husk3 = hive3.try_into_husk().unwrap(); + // let (_, outcomes3) = husk3.into_parts(); + // let mut outputs3 = outcomes3 + // .into_iter() + // .map(Outcome::unwrap) + // .collect::>(); + // outputs3.sort(); + // assert_eq!(outputs1, outputs3); + // } + + // #[test] + // fn test_clone() { + // let hive = Builder::new() + // .thread_name("clone example") + // .num_threads(2) + // .build_with_default::, Global<_>, Local<_>>(); + + // // This batch of tasks will occupy the pool for some time + // for _ in 0..6 { + // hive.apply_store(Thunk::of(|| { + // thread::sleep(SHORT_TASK); + // })); + // } + + // // The following tasks will be inserted into the pool in a random fashion + // let t0 = { + // let hive = hive.clone(); + // thread::spawn(move || { + // // wait for the first batch of tasks to finish + // hive.join(); + + // let (tx, rx) = mpsc::channel(); + // for i in 0..42 { + // let tx = tx.clone(); + // hive.apply_store(Thunk::of(move || { + // tx.send(i).expect("channel will be waiting"); + // })); + // } + // drop(tx); + // rx.iter().sum::() + // }) + // }; + // let t1 = { + // let pool = hive.clone(); + // thread::spawn(move || { + // // wait for the first batch of tasks to finish + // pool.join(); + + // let (tx, rx) = mpsc::channel(); + // for i in 1..12 { + // let tx = tx.clone(); + // 
pool.apply_store(Thunk::of(move || { + // tx.send(i).expect("channel will be waiting"); + // })); + // } + // drop(tx); + // rx.iter().product::() + // }) + // }; + + // assert_eq!( + // 861, + // t0.join() + // .expect("thread 0 will return after calculating additions",) + // ); + // assert_eq!( + // 39916800, + // t1.join() + // .expect("thread 1 will return after calculating multiplications",) + // ); + // } + + // type VoidThunkWorker = ThunkWorker<()>; + // type VoidThunkWorkerHive = Hive< + // VoidThunkWorker, + // DefaultQueen, + // ChannelGlobalQueue, + // DefaultLocalQueues>, + // >; + + // #[test] + // fn test_send() { + // fn assert_send() {} + // assert_send::(); + // } + + // #[test] + // fn test_cloned_eq() { + // let a = thunk_hive::<()>(2); + // assert_eq!(a, a.clone()); + // } + + // #[test] + // /// When a thread joins on a pool, it blocks until all tasks have completed. If a second thread + // /// adds tasks to the pool and then joins before all the tasks have completed, both threads + // /// will wait for all tasks to complete. However, as soon as all tasks have completed, all + // /// joining threads are notified, and the first one to wake will exit the join and increment + // /// the phase of the condvar. Subsequent notified threads will then see that the phase has been + // /// changed and will wake, even if new tasks have been added in the meantime. + // /// + // /// In this example, this means the waiting threads will exit the join in groups of four + // /// because the waiter pool has four threads. + // fn test_join_wavesurfer() { + // let n_waves = 4; + // let n_workers = 4; + // let (tx, rx) = mpsc::channel(); + // let builder = Builder::new() + // .num_threads(n_workers) + // .thread_name("join wavesurfer"); + // let waiter_hive = builder + // .clone() + // .build_with_default::, Global<_>, Local<_>>(); + // let clock_hive = builder.build_with_default::, Global<_>, Local<_>>(); + + // let barrier = Arc::new(Barrier::new(3)); + // let wave_counter = Arc::new(AtomicUsize::new(0)); + // let clock_thread = { + // let barrier = barrier.clone(); + // let wave_counter = wave_counter.clone(); + // thread::spawn(move || { + // barrier.wait(); + // for wave_num in 0..n_waves { + // let _ = wave_counter.swap(wave_num, Ordering::SeqCst); + // thread::sleep(ONE_SEC); + // } + // }) + // }; + + // { + // let barrier = barrier.clone(); + // clock_hive.apply_store(Thunk::of(move || { + // barrier.wait(); + // // this sleep is for stabilisation on weaker platforms + // thread::sleep(Duration::from_millis(100)); + // })); + // } + + // // prepare three waves of tasks (0..=11) + // for worker in 0..(3 * n_workers) { + // let tx = tx.clone(); + // let clock_hive = clock_hive.clone(); + // let wave_counter = wave_counter.clone(); + // waiter_hive.apply_store(Thunk::of(move || { + // let wave_before = wave_counter.load(Ordering::SeqCst); + // clock_hive.join(); + // // submit tasks for the next wave + // clock_hive.apply_store(Thunk::of(|| thread::sleep(ONE_SEC))); + // let wave_after = wave_counter.load(Ordering::SeqCst); + // tx.send((wave_before, wave_after, worker)).unwrap(); + // })); + // } + // barrier.wait(); + + // clock_hive.join(); + + // drop(tx); + // let mut hist = vec![0; n_waves]; + // let mut data = vec![]; + // for (before, after, worker) in rx.iter() { + // let mut dur = after - before; + // if dur >= n_waves - 1 { + // dur = n_waves - 1; + // } + // hist[dur] += 1; + // data.push((before, after, worker)); + // } + + // println!("Histogram of wave duration:"); + 
// for (i, n) in hist.iter().enumerate() { + // println!( + // "\t{}: {} {}", + // i, + // n, + // &*(0..*n).fold("".to_owned(), |s, _| s + "*") + // ); + // } + + // for (wave_before, wave_after, worker) in data.iter() { + // if *worker < n_workers { + // assert_eq!(wave_before, wave_after); + // } else { + // assert!(wave_before < wave_after); + // } + // } + // clock_thread.join().unwrap(); + // } + + // // cargo-llvm-cov doesn't yet support doctests in stable, so we need to duplicate them in + // // unit tests to get coverage + + // #[test] + // fn doctest_lib_2() { + // // create a hive to process `Thunk`s - no-argument closures with the same return type (`i32`) + // let hive = Builder::new() + // .num_threads(4) + // .thread_name("thunk_hive") + // .build_with_default::, Global<_>, Local<_>>(); + + // // return results to your own channel... + // let (tx, rx) = crate::hive::outcome_channel(); + // let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::of(move || i * i)), &tx); + // let outputs: Vec<_> = rx.select_unordered_outputs(task_ids).collect(); + // assert_eq!(285, outputs.into_iter().sum()); + + // // return results as an iterator... + // let outputs2: Vec<_> = hive + // .swarm((0..10).map(|i: i32| Thunk::of(move || i * -i))) + // .into_outputs() + // .collect(); + // assert_eq!(-285, outputs2.into_iter().sum()); + // } + + // #[test] + // fn doctest_lib_3() { + // #[derive(Debug)] + // struct CatWorker { + // stdin: ChildStdin, + // stdout: BufReader, + // } + + // impl CatWorker { + // fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self { + // Self { + // stdin, + // stdout: BufReader::new(stdout), + // } + // } + + // fn write_char(&mut self, c: u8) -> io::Result { + // self.stdin.write_all(&[c])?; + // self.stdin.write_all(b"\n")?; + // self.stdin.flush()?; + // let mut s = String::new(); + // self.stdout.read_line(&mut s)?; + // s.pop(); // exclude newline + // Ok(s) + // } + // } + + // impl Worker for CatWorker { + // type Input = u8; + // type Output = String; + // type Error = io::Error; + + // fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + // self.write_char(input).map_err(|error| ApplyError::Fatal { + // input: Some(input), + // error, + // }) + // } + // } + + // #[derive(Default)] + // struct CatQueen { + // children: Vec, + // } + + // impl CatQueen { + // fn wait_for_all(&mut self) -> Vec> { + // self.children + // .drain(..) 
+ // .map(|mut child| child.wait()) + // .collect() + // } + // } + + // impl QueenMut for CatQueen { + // type Kind = CatWorker; + + // fn create(&mut self) -> Self::Kind { + // let mut child = Command::new("cat") + // .stdin(Stdio::piped()) + // .stdout(Stdio::piped()) + // .stderr(Stdio::inherit()) + // .spawn() + // .unwrap(); + // let stdin = child.stdin.take().unwrap(); + // let stdout = child.stdout.take().unwrap(); + // self.children.push(child); + // CatWorker::new(stdin, stdout) + // } + // } + + // impl Drop for CatQueen { + // fn drop(&mut self) { + // self.wait_for_all() + // .into_iter() + // .for_each(|result| match result { + // Ok(status) if status.success() => (), + // Ok(status) => eprintln!("Child process failed: {}", status), + // Err(e) => eprintln!("Error waiting for child process: {}", e), + // }) + // } + // } + + // // build the Hive + // let hive = Builder::new() + // .num_threads(4) + // .build_default::, Global<_>, Local<_>>(); + + // // prepare inputs + // let inputs: Vec = (0..8).map(|i| 97 + i).collect(); + + // // execute tasks and collect outputs + // let output = hive + // .swarm(inputs) + // .into_outputs() + // .fold(String::new(), |mut a, b| { + // a.push_str(&b); + // a + // }) + // .into_bytes(); + + // // verify the output - note that `swarm` ensures the outputs are in the same order + // // as the inputs + // assert_eq!(output, b"abcdefgh"); + + // // shutdown the hive, use the Queen to wait on child processes, and report errors + // let mut queen = hive.try_into_husk().unwrap().into_parts().0.into_inner(); + // let (wait_ok, wait_err): (Vec<_>, Vec<_>) = + // queen.wait_for_all().into_iter().partition(Result::is_ok); + // if !wait_err.is_empty() { + // panic!( + // "Error(s) occurred while waiting for child processes: {:?}", + // wait_err + // ); + // } + // let exec_err_codes: Vec<_> = wait_ok + // .into_iter() + // .map(Result::unwrap) + // .filter(|status| !status.success()) + // .filter_map(|status| status.code()) + // .collect(); + // if !exec_err_codes.is_empty() { + // panic!( + // "Child process(es) failed with exit codes: {:?}", + // exec_err_codes + // ); + // } + // } + // } + + // #[cfg(all(test, feature = "affinity"))] + // mod affinity_tests { + // use crate::bee::stock::{Thunk, ThunkWorker}; + // use crate::hive::queue::{ChannelGlobalQueue, DefaultLocalQueues}; + // use crate::hive::Builder; + + // type Global = ChannelGlobalQueue; + // type Local = DefaultLocalQueues>; + + // #[test] + // fn test_affinity() { + // let hive = Builder::new() + // .thread_name("affinity example") + // .num_threads(2) + // .core_affinity(0..2) + // .build_with_default::, Global<_>, Local<_>>(); + + // hive.map_store((0..10).map(move |i| { + // Thunk::of(move || { + // if let Some(affinity) = core_affinity::get_core_ids() { + // eprintln!("task {} on thread with affinity {:?}", i, affinity); + // } + // }) + // })); + // } + + // #[test] + // fn test_use_all_cores() { + // let hive = Builder::new() + // .thread_name("affinity example") + // .with_thread_per_core() + // .with_default_core_affinity() + // .build_with_default::, Global<_>, Local<_>>(); + + // hive.map_store((0..num_cpus::get()).map(move |i| { + // Thunk::of(move || { + // if let Some(affinity) = core_affinity::get_core_ids() { + // eprintln!("task {} on thread with affinity {:?}", i, affinity); + // } + // }) + // })); + // } + // } + + // #[cfg(all(test, feature = "batching"))] + // mod batching_tests { + // use crate::barrier::IndexedBarrier; + // use crate::bee::stock::{Thunk, 
ThunkWorker}; + // use crate::bee::DefaultQueen; + // use crate::hive::queue::{ChannelGlobalQueue, DefaultLocalQueues}; + // use crate::hive::{Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender}; + // use std::collections::HashMap; + // use std::thread::{self, ThreadId}; + // use std::time::Duration; + + // type Global = ChannelGlobalQueue; + // type Local = DefaultLocalQueues>; + + // fn launch_tasks( + // hive: &Hive< + // ThunkWorker, + // DefaultQueen>, + // ChannelGlobalQueue>, + // DefaultLocalQueues, ChannelGlobalQueue>>, + // >, + // num_threads: usize, + // num_tasks_per_thread: usize, + // barrier: &IndexedBarrier, + // tx: &OutcomeSender>, + // ) -> Vec { + // let total_tasks = num_threads * num_tasks_per_thread; + // // send the first `num_threads` tasks widely spaced, so each worker thread only gets one + // let init_task_ids: Vec<_> = (0..num_threads) + // .map(|_| { + // let barrier = barrier.clone(); + // let task_id = hive.apply_send( + // Thunk::of(move || { + // barrier.wait(); + // thread::sleep(Duration::from_millis(100)); + // thread::current().id() + // }), + // tx, + // ); + // thread::sleep(Duration::from_millis(100)); + // task_id + // }) + // .collect(); + // // send the rest all at once + // let rest_task_ids = hive.map_send( + // (num_threads..total_tasks).map(|_| { + // Thunk::of(move || { + // thread::sleep(Duration::from_millis(1)); + // thread::current().id() + // }) + // }), + // tx, + // ); + // init_task_ids.into_iter().chain(rest_task_ids).collect() + // } + + // fn count_thread_ids( + // rx: OutcomeReceiver>, + // task_ids: Vec, + // ) -> HashMap { + // rx.select_unordered_outputs(task_ids) + // .fold(HashMap::new(), |mut counter, id| { + // *counter.entry(id).or_insert(0) += 1; + // counter + // }) + // } + + // fn run_test( + // hive: &Hive< + // ThunkWorker, + // DefaultQueen>, + // ChannelGlobalQueue>, + // DefaultLocalQueues, ChannelGlobalQueue>>, + // >, + // num_threads: usize, + // batch_size: usize, + // ) { + // let tasks_per_thread = batch_size + 2; + // let (tx, rx) = crate::hive::outcome_channel(); + // // each worker should take `batch_size` tasks for its queue + 1 to work on immediately, + // // meaning there should be `batch_size + 1` tasks associated with each thread ID + // let barrier = IndexedBarrier::new(num_threads); + // let task_ids = launch_tasks(hive, num_threads, tasks_per_thread, &barrier, &tx); + // // start the first tasks + // barrier.wait(); + // // wait for all tasks to complete + // hive.join(); + // let thread_counts = count_thread_ids(rx, task_ids); + // assert_eq!(thread_counts.len(), num_threads); + // assert!(thread_counts + // .values() + // .all(|&count| count == tasks_per_thread)); + // } + + // #[test] + // fn test_batching() { + // const NUM_THREADS: usize = 4; + // const BATCH_SIZE: usize = 24; + // let hive = Builder::new() + // .num_threads(NUM_THREADS) + // .batch_size(BATCH_SIZE) + // .build_with_default::, Global<_>, Local<_>>(); + // run_test(&hive, NUM_THREADS, BATCH_SIZE); + // } + + // #[test] + // fn test_set_batch_size() { + // const NUM_THREADS: usize = 4; + // const BATCH_SIZE_0: usize = 10; + // const BATCH_SIZE_1: usize = 20; + // const BATCH_SIZE_2: usize = 50; + // let hive = Builder::new() + // .num_threads(NUM_THREADS) + // .batch_size(BATCH_SIZE_0) + // .build_with_default::, Global<_>, Local<_>>(); + // run_test(&hive, NUM_THREADS, BATCH_SIZE_0); + // // increase batch size + // hive.set_worker_batch_size(BATCH_SIZE_2); + // run_test(&hive, NUM_THREADS, BATCH_SIZE_2); + // 
// decrease batch size + // hive.set_worker_batch_size(BATCH_SIZE_1); + // run_test(&hive, NUM_THREADS, BATCH_SIZE_1); + // } + + // #[test] + // fn test_shrink_batch_size() { + // const NUM_THREADS: usize = 4; + // const NUM_TASKS_PER_THREAD: usize = 125; + // const BATCH_SIZE_0: usize = 100; + // const BATCH_SIZE_1: usize = 10; + // let hive = Builder::new() + // .num_threads(NUM_THREADS) + // .batch_size(BATCH_SIZE_0) + // .build_with_default::, Global<_>, Local<_>>(); + // let (tx, rx) = crate::hive::outcome_channel(); + // let barrier = IndexedBarrier::new(NUM_THREADS); + // let task_ids = launch_tasks(&hive, NUM_THREADS, NUM_TASKS_PER_THREAD, &barrier, &tx); + // let total_tasks = NUM_THREADS * NUM_TASKS_PER_THREAD; + // assert_eq!(task_ids.len(), total_tasks); + // barrier.wait(); + // hive.set_worker_batch_size(BATCH_SIZE_1); + // // The number of tasks completed by each thread could be variable, so we want to ensure + // // that a) each processed at least `BATCH_SIZE_0` tasks, and b) there are a total of + // // `NUM_TASKS` outputs with no errors + // hive.join(); + // let thread_counts = count_thread_ids(rx, task_ids); + // assert!(thread_counts.values().all(|count| *count > BATCH_SIZE_0)); + // assert_eq!(thread_counts.values().sum::(), total_tasks); + // } + // } + + // #[cfg(all(test, feature = "retry"))] + // mod retry_tests { + // use crate::bee::stock::RetryCaller; + // use crate::bee::{ApplyError, Context}; + // use crate::hive::queue::{ChannelGlobalQueue, DefaultLocalQueues}; + // use crate::hive::{Builder, Hive, Outcome, OutcomeIteratorExt}; + // use std::time::{Duration, SystemTime}; + + // type Global = ChannelGlobalQueue; + // type Local = DefaultLocalQueues>; + + // fn echo_time(i: usize, ctx: &Context) -> Result> { + // let attempt = ctx.attempt(); + // if attempt == 3 { + // Ok("Success".into()) + // } else { + // // the delay between each message should be exponential + // eprintln!("Task {} attempt {}: {:?}", i, attempt, SystemTime::now()); + // Err(ApplyError::Retryable { + // input: i, + // error: "Retryable".into(), + // }) + // } + // } + + // #[test] + // fn test_retries() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .with_thread_per_core() + // .max_retries(3) + // .retry_factor(Duration::from_secs(1)) + // .build_with(RetryCaller::of(echo_time)); + + // let v: Result, _> = hive.swarm(0..10).into_results().collect(); + // assert_eq!(v.unwrap().len(), 10); + // } + + // #[test] + // fn test_retries_fail() { + // fn sometimes_fail( + // i: usize, + // _: &Context, + // ) -> Result> { + // match i % 3 { + // 0 => Ok("Success".into()), + // 1 => Err(ApplyError::Retryable { + // input: i, + // error: "Retryable".into(), + // }), + // 2 => Err(ApplyError::Fatal { + // input: Some(i), + // error: "Fatal".into(), + // }), + // _ => unreachable!(), + // } + // } + + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .with_thread_per_core() + // .max_retries(3) + // .build_with(RetryCaller::of(sometimes_fail)); + + // let (success, retry_failed, not_retried) = hive.swarm(0..10).fold( + // (0, 0, 0), + // |(success, retry_failed, not_retried), outcome| match outcome { + // Outcome::Success { .. } => (success + 1, retry_failed, not_retried), + // Outcome::MaxRetriesAttempted { .. } => (success, retry_failed + 1, not_retried), + // Outcome::Failure { .. 
} => (success, retry_failed, not_retried + 1), + // _ => unreachable!(), + // }, + // ); + + // assert_eq!(success, 4); + // assert_eq!(retry_failed, 3); + // assert_eq!(not_retried, 3); + // } + + // #[test] + // fn test_disable_retries() { + // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() + // .with_thread_per_core() + // .with_no_retries() + // .build_with(RetryCaller::of(echo_time)); + // let v: Result, _> = hive.swarm(0..10).into_results().collect(); + // assert!(v.is_err()); + // } } diff --git a/src/hive/outcome/store.rs b/src/hive/outcome/store.rs index b38835d..eb19608 100644 --- a/src/hive/outcome/store.rs +++ b/src/hive/outcome/store.rs @@ -1,5 +1,6 @@ -use super::{DerefOutcomes, Outcome}; +use super::Outcome; use crate::bee::{TaskId, Worker}; +use sealed::DerefOutcomes; /// Traits with methods that should only be accessed internally by public traits. pub mod sealed { @@ -15,7 +16,7 @@ pub mod sealed { fn outcomes_deref(&self) -> impl Deref>>; /// Returns a mutable reference to a map of task task_id to `Outcome`. - fn outcomes_deref_mut(&mut self) -> impl DerefMut>>; + fn outcomes_deref_mut(&mut self) -> impl DerefMut>> + '_; } pub trait OwnedOutcomes: Sized { diff --git a/src/hive/shared.rs b/src/hive/shared.rs index da8007e..823bbbd 100644 --- a/src/hive/shared.rs +++ b/src/hive/shared.rs @@ -1,30 +1,30 @@ use super::{ - Config, GlobalQueue, Husk, LocalQueues, Outcome, OutcomeSender, Shared, SpawnError, Task, + Config, GlobalQueue, Husk, LocalQueues, Outcome, OutcomeSender, QueuePair, Shared, SpawnError, + Task, }; use crate::atomic::{Atomic, AtomicInt, AtomicUsize}; use crate::bee::{Queen, TaskId, Worker}; use crate::channel::SenderExt; -use parking_lot::Mutex; use std::collections::HashMap; use std::ops::DerefMut; use std::thread::{Builder, JoinHandle}; use std::{fmt, iter}; -impl Shared +impl Shared where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, + P: QueuePair, { /// Creates a new `Shared` instance with the given configuration, queen, and task receiver, /// and all other fields set to their default values. - pub fn new(config: Config, global_queue: G, queen: Q) -> Self { + pub fn new(config: Config, queen: Q) -> Self { + let (global_queue, local_queues) = P::new(); Shared { config, + queen, global_queue, - queen: Mutex::new(queen), - local_queues: Default::default(), + local_queues, spawn_results: Default::default(), num_tasks: Default::default(), next_task_id: Default::default(), @@ -145,7 +145,7 @@ where /// Returns a new `Worker` from the queen, or an error if a `Worker` could not be created. pub fn create_worker(&self) -> Q::Kind { - self.queen.lock().create() + self.queen.create() } /// Increments the number of queued tasks. 
Returns a new `Task` with the provided input and @@ -180,6 +180,9 @@ where input: W::Input, outcome_tx: Option<&OutcomeSender>, ) -> TaskId { + if self.config.num_threads.get_or_default() == 0 { + dbg!("WARNING: no worker threads are active for hive"); + } let task = self.prepare_task(input, outcome_tx); let task_id = task.id(); self.push_global(task); @@ -211,6 +214,10 @@ where T: IntoIterator, T::IntoIter: ExactSizeIterator, { + #[cfg(debug_assertions)] + if self.config.num_threads.get_or_default() == 0 { + dbg!("WARNING: no worker threads are active for hive"); + } let iter = inputs.into_iter(); let (min_size, _) = iter.size_hint(); self.num_tasks @@ -493,19 +500,18 @@ where self.drain_tasks_into_unprocessed(); Husk::new( self.config.into_unsync(), - self.queen.into_inner(), + self.queen, self.num_panics.into_inner(), self.outcomes.into_inner(), ) } } -impl fmt::Debug for Shared +impl fmt::Debug for Shared where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, + P: QueuePair, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let (queued, active) = self.num_tasks(); @@ -522,14 +528,13 @@ where mod affinity { use crate::bee::{Queen, Worker}; use crate::hive::cores::{Core, Cores}; - use crate::hive::{GlobalQueue, LocalQueues, Shared}; + use crate::hive::{QueuePair, Shared}; - impl Shared + impl Shared where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, + P: QueuePair, { /// Adds cores to which worker threads may be pinned. pub fn add_core_affinity(&self, new_cores: Cores) { @@ -552,14 +557,13 @@ mod affinity { #[cfg(feature = "batching")] mod batching { use crate::bee::{Queen, Worker}; - use crate::hive::{GlobalQueue, LocalQueues, Shared}; + use crate::hive::{LocalQueues, QueuePair, Shared}; - impl Shared + impl Shared where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, + P: QueuePair, { /// Returns the local queue batch size. pub fn batch_size(&self) -> usize { @@ -596,15 +600,14 @@ mod batching { #[cfg(feature = "retry")] mod retry { use crate::bee::{Queen, Worker}; - use crate::hive::{GlobalQueue, LocalQueues, OutcomeSender, Shared, Task, TaskId}; + use crate::hive::{LocalQueues, OutcomeSender, QueuePair, Shared, Task, TaskId}; use std::time::Instant; - impl Shared + impl Shared where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, + P: QueuePair, { /// Returns `true` if the hive is configured to retry tasks and the `attempt` field of the /// given `ctx` is less than the maximum number of retries. @@ -639,15 +642,13 @@ mod retry { mod tests { use crate::bee::stock::ThunkWorker; use crate::bee::DefaultQueen; - use crate::hive::task::ChannelGlobalQueue; - use crate::hive::ChannelLocalQueues; + use crate::hive::queue::ChannelQueues; type VoidThunkWorker = ThunkWorker<()>; type VoidThunkWorkerShared = super::Shared< VoidThunkWorker, DefaultQueen, - ChannelGlobalQueue, - ChannelLocalQueues, + ChannelQueues, >; #[test] diff --git a/src/hive/task/delay.rs b/src/hive/task/delay.rs deleted file mode 100644 index 7faa6c4..0000000 --- a/src/hive/task/delay.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::cell::UnsafeCell; -use std::cmp::Ordering; -use std::collections::BinaryHeap; -use std::time::{Duration, Instant}; - -/// A queue where each item has an associated `Instant` at which it will be available. -/// -/// This is implemented internally as a `UnsafeCell`. -/// -/// SAFETY: This data structure is designed to enable the queue to be modified by a *single thread* -/// using interior mutability. 
`UnsafeCell` is used for performance - this is safe so long as the -/// queue is only accessed from a single thread at a time. This data structure is *not* thread-safe. -#[derive(Debug)] -pub struct DelayQueue(UnsafeCell>>); - -impl DelayQueue { - /// Returns the number of items currently in the queue. - pub fn len(&self) -> usize { - unsafe { self.0.get().as_ref().unwrap().len() } - } - - /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available, - /// or an error containing `item` if the push fails. - pub fn push(&self, item: T, delay: Duration) -> Result { - unsafe { - match self.0.get().as_mut() { - Some(queue) => { - let delayed = Delayed::new(item, delay); - let until = delayed.until; - queue.push(delayed); - Ok(until) - } - None => Err(item), - } - } - } - - /// Returns the `Instant` at which the next item will be available. Returns `None` if the queue - /// is empty. - pub fn next_available(&self) -> Option { - unsafe { - self.0 - .get() - .as_ref() - .and_then(|queue| queue.peek().map(|head| head.until)) - } - } - - /// Returns the item at the head of the queue, if one exists and is available (i.e., its delay - /// has been exceeded), and removes it. - pub fn try_pop(&self) -> Option { - unsafe { - if self - .next_available() - .map(|until| until <= Instant::now()) - .unwrap_or(false) - { - self.0 - .get() - .as_mut() - .and_then(|queue| queue.pop()) - .map(|delayed| delayed.value) - } else { - None - } - } - } - - /// Drains all items from the queue and returns them as an iterator. - pub fn drain(&mut self) -> impl Iterator + '_ { - self.0.get_mut().drain().map(|delayed| delayed.value) - } -} - -unsafe impl Sync for DelayQueue {} - -impl Default for DelayQueue { - fn default() -> Self { - DelayQueue(UnsafeCell::new(BinaryHeap::new())) - } -} - -#[derive(Debug)] -struct Delayed { - value: T, - until: Instant, -} - -impl Delayed { - pub fn new(value: T, delay: Duration) -> Self { - Delayed { - value, - until: Instant::now() + delay, - } - } -} - -/// Implements ordering for `Delayed`, so it can be used to correctly order elements in the -/// `BinaryHeap` of the `DelayQueue`. -/// -/// Earlier entries have higher priority (should be popped first), so they compare as greater -/// than later entries.
-impl Ord for Delayed { - fn cmp(&self, other: &Delayed) -> Ordering { - other.until.cmp(&self.until) - } -} - -impl PartialOrd for Delayed { - fn partial_cmp(&self, other: &Delayed) -> Option { - Some(self.cmp(other)) - } -} - -impl PartialEq for Delayed { - fn eq(&self, other: &Delayed) -> bool { - self.cmp(other) == Ordering::Equal - } -} - -impl Eq for Delayed {} - -#[cfg(test)] -mod tests { - use super::DelayQueue; - use std::{thread, time::Duration}; - - #[test] - fn test_works() { - let queue = DelayQueue::default(); - - queue.push(1, Duration::from_secs(1)).unwrap(); - queue.push(2, Duration::from_secs(2)).unwrap(); - queue.push(3, Duration::from_secs(3)).unwrap(); - - assert_eq!(queue.len(), 3); - assert_eq!(queue.try_pop(), None); - - thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(1)); - assert_eq!(queue.len(), 2); - - thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(2)); - assert_eq!(queue.len(), 1); - - thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(3)); - assert_eq!(queue.len(), 0); - - assert_eq!(queue.try_pop(), None); - } - - #[test] - fn test_into_vec() { - let mut queue = DelayQueue::default(); - queue.push(1, Duration::from_secs(1)).unwrap(); - queue.push(2, Duration::from_secs(2)).unwrap(); - queue.push(3, Duration::from_secs(3)).unwrap(); - let mut v: Vec<_> = queue.drain().collect(); - v.sort(); - assert_eq!(v, vec![1, 2, 3]); - } -} diff --git a/src/hive/task/global.rs b/src/hive/task/global.rs deleted file mode 100644 index 9311ee6..0000000 --- a/src/hive/task/global.rs +++ /dev/null @@ -1,77 +0,0 @@ -pub use channel::GlobalQueueImpl as ChannelGlobalQueue; - -mod channel { - use crate::atomic::{Atomic, AtomicBool}; - use crate::bee::Worker; - use crate::hive::{GlobalPopError, GlobalQueue, Task, TaskReceiver, TaskSender}; - use crossbeam_channel::RecvTimeoutError; - use std::time::Duration; - - pub struct GlobalQueueImpl { - tx: TaskSender, - rx: TaskReceiver, - closed: AtomicBool, - } - - impl GlobalQueueImpl { - /// Returns a new `GlobalQueue` that uses the given channel sender for pushing new tasks - /// and the given channel receiver for popping tasks. 
- pub fn new(tx: TaskSender, rx: TaskReceiver) -> Self { - Self { - tx, - rx, - closed: AtomicBool::default(), - } - } - - #[cfg(feature = "batching")] - pub fn try_iter(&self) -> crossbeam_channel::TryIter> { - self.rx.try_iter() - } - - pub fn try_pop_timeout( - &self, - timeout: Duration, - ) -> Option, GlobalPopError>> { - match self.rx.recv_timeout(timeout) { - Ok(task) => Some(Ok(task)), - Err(RecvTimeoutError::Disconnected) => Some(Err(GlobalPopError::Closed)), - Err(RecvTimeoutError::Timeout) if self.closed.get() && self.rx.is_empty() => { - Some(Err(GlobalPopError::Closed)) - } - Err(RecvTimeoutError::Timeout) => None, - } - } - } - - impl GlobalQueue for GlobalQueueImpl { - fn try_push(&self, task: Task) -> Result<(), Task> { - if !self.closed.get() { - self.tx.try_send(task).map_err(|err| err.into_inner()) - } else { - Err(task) - } - } - - fn try_pop(&self) -> Option, GlobalPopError>> { - // time to wait in between polling the retry queue and then the task receiver - const RECV_TIMEOUT: Duration = Duration::from_secs(1); - self.try_pop_timeout(RECV_TIMEOUT) - } - - fn drain(&self) -> Vec> { - self.rx.try_iter().collect() - } - - fn close(&self) { - self.closed.set(true); - } - } - - impl Default for GlobalQueueImpl { - fn default() -> Self { - let (tx, rx) = crossbeam_channel::unbounded(); - Self::new(tx, rx) - } - } -} diff --git a/src/hive/task/local.rs b/src/hive/task/local.rs deleted file mode 100644 index 1a8e1e5..0000000 --- a/src/hive/task/local.rs +++ /dev/null @@ -1,386 +0,0 @@ -#[cfg(any(feature = "batching", feature = "retry"))] -pub use channel::LocalQueuesImpl as ChannelLocalQueues; -#[cfg(not(any(feature = "batching", feature = "retry")))] -pub use null::LocalQueuesImpl as ChannelLocalQueues; - -#[cfg(not(any(feature = "batching", feature = "retry")))] -mod null { - use crate::bee::{Queen, Worker}; - use crate::hive::{ChannelGlobalQueue, LocalQueues, Shared, Task}; - use std::marker::PhantomData; - - pub struct LocalQueuesImpl(PhantomData W>); - - impl LocalQueues> for LocalQueuesImpl { - fn init_for_threads>( - &self, - _: usize, - _: usize, - _: &Shared, Self>, - ) { - } - - #[inline(always)] - fn push>( - &self, - task: Task, - _: usize, - shared: &Shared, Self>, - ) { - shared.push_global(task); - } - - #[inline(always)] - fn try_pop>( - &self, - _: usize, - _: &Shared, Self>, - ) -> Option> { - None - } - - fn drain(&self) -> Vec> { - Vec::new() - } - } - - impl Default for LocalQueuesImpl { - fn default() -> Self { - Self(PhantomData) - } - } -} - -#[cfg(any(feature = "batching", feature = "retry"))] -mod channel { - use crate::bee::{Queen, Worker}; - use crate::hive::task::ChannelGlobalQueue; - use crate::hive::{LocalQueues, Shared, Task}; - use parking_lot::RwLock; - - pub struct LocalQueuesImpl { - /// thread-local queues of tasks used when the `batching` feature is enabled - #[cfg(feature = "batching")] - batch_queues: RwLock>>>, - /// thread-local queues used for tasks that are waiting to be retried after a failure - #[cfg(feature = "retry")] - retry_queues: RwLock>>>, - } - - #[cfg(feature = "retry")] - impl LocalQueuesImpl { - #[inline] - fn try_pop_retry(&self, thread_index: usize) -> Option> { - self.retry_queues - .read() - .get(thread_index) - .and_then(|queue| queue.try_pop()) - } - } - - #[cfg(feature = "batching")] - impl LocalQueuesImpl { - // time to wait in between polling the retry queue and then the task receiver - const POP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(1); - - #[inline] - fn try_push_local(&self, task: 
Task, thread_index: usize) -> Result<(), Task> { - self.batch_queues.read()[thread_index].push(task) - } - - #[inline] - fn try_pop_local_or_refill>( - &self, - thread_index: usize, - shared: &Shared, Self>, - ) -> Option> { - let local_queue = &self.batch_queues.read()[thread_index]; - // pop from the local queue if it has any tasks - if !local_queue.is_empty() { - return local_queue.pop(); - } - // otherwise pull at least 1 and up to `batch_size + 1` tasks from the input channel - // wait for the next task from the receiver - let first = shared - .global_queue - .try_pop_timeout(Self::POP_TIMEOUT) - .and_then(Result::ok); - // if we fail after trying to get one, don't keep trying to fill the queue - if first.is_some() { - let batch_size = shared.batch_size(); - // batch size 0 means batching is disabled - if batch_size > 0 { - // otherwise try to take up to `batch_size` tasks from the input channel - // and add them to the local queue, but don't block if the input channel - // is empty - for result in shared - .global_queue - .try_iter() - .take(batch_size) - .map(|task| local_queue.push(task)) - { - if let Err(task) = result { - // for some reason we can't push the task to the local queue; - // this should never happen, but just in case we turn it into an - // unprocessed outcome and stop iterating - shared.abandon_task(task); - break; - } - } - } - } - first - } - } - - impl LocalQueues> for LocalQueuesImpl { - fn init_for_threads>( - &self, - start_index: usize, - end_index: usize, - #[allow(unused_variables)] shared: &Shared, Self>, - ) { - #[cfg(feature = "batching")] - self.init_batch_queues_for_threads(start_index, end_index, shared); - #[cfg(feature = "retry")] - self.init_retry_queues_for_threads(start_index, end_index); - } - - #[cfg(feature = "batching")] - fn resize>( - &self, - start_index: usize, - end_index: usize, - new_size: usize, - shared: &Shared, Self>, - ) { - self.resize_batch_queues(start_index, end_index, new_size, shared); - } - - /// Creates a task from `input` and pushes it to the local queue if there is space, - /// otherwise attempts to add it to the global queue. Returns the task ID if the push - /// succeeds, otherwise returns an error with the input. - fn push>( - &self, - task: Task, - #[allow(unused_variables)] thread_index: usize, - shared: &Shared, Self>, - ) { - #[cfg(feature = "batching")] - let task = match self.try_push_local(task, thread_index) { - Ok(_) => return, - Err(task) => task, - }; - shared.push_global(task); - } - - /// Returns the next task from the local queue if there are any, otherwise attempts to - /// fetch at least 1 and up to `batch_size + 1` tasks from the input channel and puts all - /// but the first one into the local queue. 
- fn try_pop>( - &self, - thread_index: usize, - #[allow(unused_variables)] shared: &Shared, Self>, - ) -> Option> { - #[cfg(feature = "retry")] - if let Some(task) = self.try_pop_retry(thread_index) { - return Some(task); - } - #[cfg(feature = "batching")] - if let Some(task) = self.try_pop_local_or_refill(thread_index, shared) { - return Some(task); - } - None - } - - fn drain(&self) -> Vec> { - let mut tasks = Vec::new(); - #[cfg(feature = "batching")] - { - self.drain_batch_queues_into(&mut tasks); - } - #[cfg(feature = "retry")] - { - self.drain_retry_queues_into(&mut tasks); - } - tasks - } - - #[cfg(feature = "retry")] - fn retry>( - &self, - task: Task, - thread_index: usize, - shared: &Shared, Self>, - ) -> Option { - self.try_push_retry(task, thread_index, shared) - } - } - - impl Default for LocalQueuesImpl { - fn default() -> Self { - Self { - #[cfg(feature = "batching")] - batch_queues: Default::default(), - #[cfg(feature = "retry")] - retry_queues: Default::default(), - } - } - } - - #[cfg(feature = "batching")] - mod batching { - use super::LocalQueuesImpl; - use crate::bee::{Queen, Worker}; - use crate::hive::{ChannelGlobalQueue, Shared, Task}; - use crossbeam_queue::ArrayQueue; - use std::collections::HashSet; - use std::time::Duration; - - impl LocalQueuesImpl { - pub(super) fn init_batch_queues_for_threads>( - &self, - start_index: usize, - end_index: usize, - shared: &Shared, Self>, - ) { - let mut batch_queues = self.batch_queues.write(); - assert_eq!(batch_queues.len(), start_index); - let queue_size = shared.batch_size().max(1); - (start_index..end_index) - .for_each(|_| batch_queues.push(ArrayQueue::new(queue_size))); - } - - pub(super) fn resize_batch_queues>( - &self, - start_index: usize, - end_index: usize, - batch_size: usize, - shared: &Shared, Self>, - ) { - // keep track of which queues need to be resized - // TODO: this method could cause a hang if one of the worker threads is stuck - we - // might want to keep track of each queue's size and if we don't see it shrink - // within a certain amount of time, we give up on that thread and leave it with a - // wrong-sized queue (which should never cause a panic) - let mut to_resize: HashSet = (start_index..end_index).collect(); - // iterate until we've resized them all - loop { - // scope the mutable access to local_queues - { - let mut batch_queues = self.batch_queues.write(); - to_resize.retain(|thread_index| { - let queue = if let Some(queue) = batch_queues.get_mut(*thread_index) { - queue - } else { - return false; - }; - if queue.len() > batch_size { - return true; - } - let new_queue = ArrayQueue::new(batch_size); - while let Some(task) = queue.pop() { - if let Err(task) = new_queue.push(task) { - // for some reason we can't push the task to the new queue - // this should never happen, but just in case we turn it into - // an unprocessed outcome - shared.abandon_task(task); - } - } - // this is safe because the worker threads can't get readable access to the - // queue while this thread holds the lock - let old_queue = std::mem::replace(queue, new_queue); - assert!(old_queue.is_empty()); - false - }); - } - if !to_resize.is_empty() { - // short sleep to give worker threads the chance to pull from their queues - std::thread::sleep(Duration::from_millis(10)); - } - } - } - - pub(super) fn drain_batch_queues_into(&self, tasks: &mut Vec>) { - let _ = self - .batch_queues - .write() - .iter_mut() - .fold(tasks, |tasks, queue| { - tasks.reserve(queue.len()); - while let Some(task) = queue.pop() { - 
tasks.push(task); - } - tasks - }); - } - } - } - - #[cfg(feature = "retry")] - mod retry { - use super::LocalQueuesImpl; - use crate::bee::{Queen, Worker}; - use crate::hive::task::delay::DelayQueue; - use crate::hive::{ChannelGlobalQueue, Shared, Task}; - use std::time::{Duration, Instant}; - - impl LocalQueuesImpl { - /// Initializes the retry queues worker threads in the specified range. - pub(super) fn init_retry_queues_for_threads( - &self, - start_index: usize, - end_index: usize, - ) { - let mut retry_queues = self.retry_queues.write(); - assert_eq!(retry_queues.len(), start_index); - (start_index..end_index).for_each(|_| retry_queues.push(DelayQueue::default())) - } - - /// Adds a task to the retry queue with a delay based on `attempt`. - pub(super) fn try_push_retry>( - &self, - task: Task, - thread_index: usize, - shared: &Shared, Self>, - ) -> Option { - // compute the delay - let delay = shared - .config - .retry_factor - .get() - .map(|retry_factor| { - 2u64.checked_pow(task.attempt - 1) - .and_then(|multiplier| { - retry_factor - .checked_mul(multiplier) - .or(Some(u64::MAX)) - .map(Duration::from_nanos) - }) - .unwrap() - }) - .unwrap_or_default(); - if let Some(queue) = self.retry_queues.read().get(thread_index) { - queue.push(task, delay) - } else { - Err(task) - } - // if unable to queue the task, abandon it - .map_err(|task| shared.abandon_task(task)) - .ok() - } - - pub(super) fn drain_retry_queues_into(&self, tasks: &mut Vec>) { - let _ = self - .retry_queues - .write() - .iter_mut() - .fold(tasks, |tasks, queue| { - tasks.reserve(queue.len()); - tasks.extend(queue.drain()); - tasks - }); - } - } - } -} diff --git a/src/hive/task/mod.rs b/src/hive/task/mod.rs deleted file mode 100644 index 0728e90..0000000 --- a/src/hive/task/mod.rs +++ /dev/null @@ -1,75 +0,0 @@ -#[cfg(feature = "retry")] -mod delay; -mod global; -mod local; - -pub use global::ChannelGlobalQueue; -pub use local::ChannelLocalQueues; - -use super::{Outcome, OutcomeSender, Task}; -use crate::bee::{TaskId, Worker}; - -impl Task { - /// Returns the ID of this task. - pub fn id(&self) -> TaskId { - self.id - } - - /// Consumes this `Task` and returns a `Outcome::Unprocessed` outcome with the input and ID, - /// and the outcome sender. - pub fn into_unprocessed(self) -> (Outcome, Option>) { - let outcome = Outcome::Unprocessed { - input: self.input, - task_id: self.id, - }; - (outcome, self.outcome_tx) - } -} - -#[cfg(not(feature = "retry"))] -impl Task { - /// Creates a new `Task`. - pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { - Task { - id, - input, - outcome_tx, - } - } - - pub fn into_parts(self) -> (TaskId, W::Input, Option>) { - (self.id, self.input, self.outcome_tx) - } -} - -#[cfg(feature = "retry")] -impl Task { - /// Creates a new `Task`. - pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { - Task { - id, - input, - outcome_tx, - attempt: 0, - } - } - - /// Creates a new `Task`. - pub fn with_attempt( - id: TaskId, - input: W::Input, - outcome_tx: Option>, - attempt: u32, - ) -> Self { - Task { - id, - input, - outcome_tx, - attempt, - } - } - - pub fn into_parts(self) -> (TaskId, W::Input, u32, Option>) { - (self.id, self.input, self.attempt, self.outcome_tx) - } -} diff --git a/src/util.rs b/src/util.rs index f326057..1160293 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,7 +4,9 @@ //! creating the [`Hive`](crate::hive::Hive), submitting tasks, collecting results, and shutting //! down the `Hive` properly. 
use crate::bee::stock::{Caller, OnceCaller}; -use crate::hive::{Builder, Outcome, OutcomeBatch}; +use crate::hive::{ + Builder, ChannelGlobalQueue, ChannelQueues, DefaultLocalQueues, Outcome, OutcomeBatch, +}; use std::fmt::Debug; /// Convenience function that creates a `Hive` with `num_threads` worker threads that execute the @@ -30,7 +32,7 @@ where { Builder::default() .num_threads(num_threads) - .build_with(Caller::of(f)) + .build_with::<_, ChannelQueues<_>>(Caller::of(f)) .map(inputs) .map(Outcome::unwrap) .collect() @@ -69,7 +71,7 @@ where { Builder::default() .num_threads(num_threads) - .build_with(OnceCaller::of(f)) + .build_with::<_, ChannelQueues<_>>(OnceCaller::of(f)) .map(inputs) .into() } @@ -115,7 +117,7 @@ pub use retry::try_map_retryable; mod retry { use crate::bee::stock::RetryCaller; use crate::bee::{ApplyError, Context}; - use crate::hive::{Builder, OutcomeBatch}; + use crate::hive::{Builder, ChannelQueues, OutcomeBatch}; use std::fmt::Debug; /// Convenience function that creates a `Hive` with `num_threads` worker threads that execute the @@ -159,7 +161,7 @@ mod retry { Builder::default() .num_threads(num_threads) .max_retries(max_retries) - .build_with(RetryCaller::of(f)) + .build_with::<_, ChannelQueues<_>>(RetryCaller::of(f)) .map(inputs) .into() } From 88e366fd110ccb232f0c02bd47b650cc4bdaea4d Mon Sep 17 00:00:00 2001 From: jdidion Date: Tue, 11 Feb 2025 10:14:42 -0800 Subject: [PATCH 08/67] WIP --- src/hive/core.rs | 1053 ++++++++++++++++++++++++++++++++++++++ src/hive/queue/delay.rs | 164 ++++++ src/hive/queue/global.rs | 77 +++ src/hive/queue/local.rs | 332 ++++++++++++ src/hive/queue/mod.rs | 44 ++ src/hive/queue/null.rs | 44 ++ src/hive/task.rs | 67 +++ 7 files changed, 1781 insertions(+) create mode 100644 src/hive/core.rs create mode 100644 src/hive/queue/delay.rs create mode 100644 src/hive/queue/global.rs create mode 100644 src/hive/queue/local.rs create mode 100644 src/hive/queue/mod.rs create mode 100644 src/hive/queue/null.rs create mode 100644 src/hive/task.rs diff --git a/src/hive/core.rs b/src/hive/core.rs new file mode 100644 index 0000000..2ae15c5 --- /dev/null +++ b/src/hive/core.rs @@ -0,0 +1,1053 @@ +use super::prelude::*; +use super::{ + Config, DerefOutcomes, GlobalQueue, LocalQueues, OutcomeSender, QueuePair, Shared, SpawnError, +}; +use crate::atomic::Atomic; +use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; +use crossbeam_utils::Backoff; +use std::collections::HashMap; +use std::fmt; +use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::thread::{self, JoinHandle}; + +#[derive(thiserror::Error, Debug)] +#[error("The hive has been poisoned")] +pub struct Poisoned; + +impl< + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, + > Hive +{ + /// Creates a new `Hive`. This should only be called from `Builder`. + /// + /// The `Hive` will attempt to spawn the configured number of worker threads + /// (`config.num_threads`) but the actual number of threads available may be lower if there + /// are any errors during spawning. + pub(super) fn new(config: Config, queen: Q) -> Self { + let shared = Arc::new(Shared::new(config.into_sync(), queen)); + shared.init_threads(|thread_index| Self::try_spawn(thread_index, &shared)); + Self(Some(shared)) + } +} + +impl< + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, + > Hive +{ + /// Spawns a new worker thread with the specified index and with access to the `shared` data. 
+ fn try_spawn( + thread_index: usize, + shared: &Arc>, + ) -> Result, SpawnError> { + let thread_builder = shared.thread_builder(); + let shared = Arc::clone(shared); + // spawn a thread that executes the worker loop + thread_builder.spawn(move || { + // perform one-time initialization of the worker thread + Self::init_thread(thread_index, &shared); + // create a Sentinel that will spawn a new thread on panic until it is cancelled + let sentinel = Sentinel::new(thread_index, Arc::clone(&shared)); + // create a new worker to process tasks + let mut worker = shared.create_worker(); + // execute the main loop: get the next task to process, which decrements the queued + // counter and increments the active counter + while let Some(task) = shared.get_next_task(thread_index) { + // execute the task and dispose of the outcome + Self::execute(task, thread_index, &mut worker, &shared); + // finish the task - decrements the active counter and notifies other threads + shared.finish_task(false); + } + // this is only reachable when the main loop exits due to the task receiver having + // disconnected; cancel the Sentinel so this thread won't be re-spawned on drop + sentinel.cancel(); + }) + } + + #[inline] + fn shared(&self) -> &Arc> { + self.0.as_ref().unwrap() + } + + /// Attempts to increase the number of worker threads by `num_threads`. Returns the number of + /// new worker threads that were successfully started (which may be fewer than `num_threads`), + /// or a `Poisoned` error if the hive has been poisoned. + pub fn grow(&self, num_threads: usize) -> Result { + if num_threads == 0 { + return Ok(0); + } + let shared = self.shared(); + // do not start any new threads if the hive is poisoned + if shared.is_poisoned() { + return Err(Poisoned); + } + let num_started = shared.grow_threads(num_threads, |thread_index| { + Self::try_spawn(thread_index, shared) + }); + Ok(num_started) + } + + /// Sets the number of worker threads to the number of available CPU cores. Returns the number + /// of new threads that were successfully started (which may be `0`), or a `Poisoned` error if + /// the hive has been poisoned. + pub fn use_all_cores(&self) -> Result { + let num_threads = num_cpus::get().saturating_sub(self.max_workers()); + self.grow(num_threads) + } + + /// Sends one `input` to the `Hive` for processing and returns the result, blocking until the + /// result is available. Creates a channel to send the input and receive the outcome. Returns + /// an [`Outcome`] with the task output or an error. + pub fn apply(&self, input: W::Input) -> Outcome { + let (tx, rx) = outcome_channel(); + let task_id = self.shared().send_one_global(input, Some(&tx)); + rx.recv().unwrap_or_else(|_| Outcome::Missing { task_id }) + } + + /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of + /// the task will be sent to `tx` upon completion. + pub fn apply_send(&self, input: W::Input, tx: &OutcomeSender) -> TaskId { + self.shared().send_one_global(input, Some(tx)) + } + + /// Sends one `input` to the `Hive` for processing and returns its ID immediately. The + /// [`Outcome`] of the task will be retained and available for later retrieval. + pub fn apply_store(&self, input: W::Input) -> TaskId { + self.shared().send_one_global(input, None) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns an iterator over the + /// [`Outcome`]s in the same order as the inputs.
+ /// + /// This method is more efficient than [`map`](Self::map) when the input is an + /// [`ExactSizeIterator`]. + pub fn swarm(&self, batch: T) -> impl Iterator> + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + let (tx, rx) = outcome_channel(); + let task_ids = self.shared().send_batch_global(batch, Some(&tx)); + rx.select_ordered(task_ids) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns an unordered iterator + /// over the [`Outcome`]s. + /// + /// The `Outcome`s will be sent in the order they are completed; use [`swarm`](Self::swarm) to + /// instead receive the `Outcome`s in the order they were submitted. This method is more + /// efficient than [`map_unordered`](Self::map_unordered) when the input is an + /// [`ExactSizeIterator`]. + pub fn swarm_unordered(&self, batch: T) -> impl Iterator> + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + let (tx, rx) = outcome_channel(); + let task_ids = self.shared().send_batch_global(batch, Some(&tx)); + rx.select_unordered(task_ids) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. + /// The [`Outcome`]s of the tasks will be sent to `tx` upon completion. + /// + /// This method is more efficient than [`map_send`](Self::map_send) when the input is an + /// [`ExactSizeIterator`]. + pub fn swarm_send(&self, batch: T, outcome_tx: &OutcomeSender) -> Vec + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + self.shared().send_batch_global(batch, Some(outcome_tx)) + } + + /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. + /// The [`Outcome`]s of the tasks are retained and available for later retrieval. + /// + /// This method is more efficient than [`map_store`](Self::map_store) when the input is an + /// [`ExactSizeIterator`]. + pub fn swarm_store(&self, batch: T) -> Vec + where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, + { + self.shared().send_batch_global(batch, None) + } + + /// Iterates over `inputs` and sends each one to the `Hive` for processing, and returns an + /// iterator over the [`Outcome`]s in the same order as the inputs. + /// + /// [`swarm`](Self::swarm) should be preferred when `inputs` is an [`ExactSizeIterator`]. + pub fn map( + &self, + inputs: impl IntoIterator, + ) -> impl Iterator> { + let (tx, rx) = outcome_channel(); + let task_ids: Vec<_> = inputs + .into_iter() + .map(|task| self.apply_send(task, &tx)) + .collect(); + rx.select_ordered(task_ids) + } + + /// Iterates over `inputs`, sends each one to the `Hive` for processing, and returns an + /// iterator over the [`Outcome`]s in the order they become available. + /// + /// [`swarm_unordered`](Self::swarm_unordered) should be preferred when `inputs` is an + /// [`ExactSizeIterator`]. + pub fn map_unordered( + &self, + inputs: impl IntoIterator, + ) -> impl Iterator> { + let (tx, rx) = outcome_channel(); + // `map` is required (rather than `inspect`) because we need owned items + let task_ids: Vec<_> = inputs + .into_iter() + .map(|task| self.apply_send(task, &tx)) + .collect(); + rx.select_unordered(task_ids) + } + + /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] + /// of task IDs. The [`Outcome`]s of the tasks will be sent to `tx` upon completion. + /// + /// [`swarm_send`](Self::swarm_send) should be preferred when `inputs` is an + /// [`ExactSizeIterator`].
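+ ///
+ /// # Examples
+ ///
+ /// A usage sketch, modeled on the other doctests in this crate (the `ThunkWorker` stock
+ /// worker and the `OutcomeIteratorExt` import are assumptions carried over from those
+ /// examples):
+ ///
+ /// ```
+ /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+ /// use beekeeper::hive::{outcome_channel, Builder, OutcomeIteratorExt};
+ ///
+ /// let hive = Builder::new()
+ ///     .num_threads(4)
+ ///     .build_with_default::<ThunkWorker<i32>>();
+ /// // send the outcomes to our own channel as the tasks complete
+ /// let (tx, rx) = outcome_channel();
+ /// let task_ids = hive.map_send((0..10).map(|i| Thunk::of(move || i * i)), &tx);
+ /// let sum: i32 = rx.select_unordered_outputs(task_ids).sum();
+ /// assert_eq!(sum, 285);
+ /// ```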
+ pub fn map_send( + &self, + inputs: impl IntoIterator, + tx: &OutcomeSender, + ) -> Vec { + inputs + .into_iter() + .map(|input| self.apply_send(input, tx)) + .collect() + } + + /// Iterates over `inputs` and sends each one to the `Hive` for processing. Returns a [`Vec`] + /// of task IDs. The [`Outcome`]s of the tasks are retained and available for later retrieval. + /// + /// [`swarm_store`](Self::swarm_store) should be preferred when `inputs` is an + /// [`ExactSizeIterator`]. + pub fn map_store(&self, inputs: impl IntoIterator) -> Vec { + inputs + .into_iter() + .map(|input| self.apply_store(input)) + .collect() + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. + /// Returns an [`OutcomeBatch`] of the outputs and the final state value. + pub fn scan( + &self, + items: impl IntoIterator, + init: St, + f: F, + ) -> (OutcomeBatch, St) + where + F: FnMut(&mut St, T) -> W::Input, + { + let (tx, rx) = outcome_channel(); + let (task_ids, fold_value) = self.scan_send(items, &tx, init, f); + let outcomes = rx.select_unordered(task_ids).into(); + (outcomes, fold_value) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, + /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final + /// state value. + pub fn try_scan( + &self, + items: impl IntoIterator, + init: St, + mut f: F, + ) -> (OutcomeBatch, Vec, St) + where + F: FnMut(&mut St, T) -> Result, + { + let (tx, rx) = outcome_channel(); + let (task_ids, errors, fold_value) = items.into_iter().fold( + (Vec::new(), Vec::new(), init), + |(mut task_ids, mut errors, mut acc), inp| { + match f(&mut acc, inp) { + Ok(input) => task_ids.push(self.apply_send(input, &tx)), + Err(err) => errors.push(err), + } + (task_ids, errors, acc) + }, + ); + let outcomes = rx.select_unordered(task_ids).into(); + (outcomes, errors, fold_value) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. + /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the + /// task IDs and the final state value. + pub fn scan_send( + &self, + items: impl IntoIterator, + tx: &OutcomeSender, + init: St, + mut f: F, + ) -> (Vec, St) + where + F: FnMut(&mut St, T) -> W::Input, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { + let input = f(&mut acc, item); + task_ids.push(self.apply_send(input, tx)); + (task_ids, acc) + }) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, + /// or an error. The outputs are sent to `tx` in the order they become available. This + /// function returns the final state value and a [`Vec`] of results, where each result is + /// either a task ID or an error.
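+ ///
+ /// # Examples
+ ///
+ /// A sketch of fallible scanning with a running-sum state (hypothetical inputs; the
+ /// `ThunkWorker` stock worker is an assumption borrowed from the other examples):
+ ///
+ /// ```
+ /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+ /// use beekeeper::hive::{outcome_channel, Builder};
+ ///
+ /// let hive = Builder::new()
+ ///     .num_threads(4)
+ ///     .build_with_default::<ThunkWorker<i32>>();
+ /// let (tx, _rx) = outcome_channel();
+ /// // reject negative items instead of queuing them; accumulate the rest into `acc`
+ /// let (results, sum) = hive.try_scan_send([1, -2, 3], &tx, 0, |acc, i: i32| {
+ ///     if i < 0 {
+ ///         return Err(format!("negative input: {}", i));
+ ///     }
+ ///     *acc += i;
+ ///     Ok(Thunk::of(move || i))
+ /// });
+ /// assert_eq!(sum, 4);
+ /// assert_eq!(results.iter().filter(|r| r.is_err()).count(), 1);
+ /// ```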
+ pub fn try_scan_send( + &self, + items: impl IntoIterator, + tx: &OutcomeSender, + init: St, + mut f: F, + ) -> (Vec>, St) + where + F: FnMut(&mut St, T) -> Result, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut results, mut acc), inp| { + results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx))); + (results, acc) + }) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. + /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of + /// the tasks are retained and available for later retrieval. + pub fn scan_store( + &self, + items: impl IntoIterator, + init: St, + mut f: F, + ) -> (Vec, St) + where + F: FnMut(&mut St, T) -> W::Input, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { + let input = f(&mut acc, item); + task_ids.push(self.apply_store(input)); + (task_ids, acc) + }) + } + + /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized + /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing, + /// or an error. This function returns the final state value and a [`Vec`] of + /// results, where each result is either a task ID or an error. The [`Outcome`]s of the + /// tasks are retained and available for later retrieval. + pub fn try_scan_store( + &self, + items: impl IntoIterator, + init: St, + mut f: F, + ) -> (Vec>, St) + where + F: FnMut(&mut St, T) -> Result, + { + items + .into_iter() + .fold((Vec::new(), init), |(mut results, mut acc), item| { + results.push(f(&mut acc, item).map(|input| self.apply_store(input))); + (results, acc) + }) + } + + /// Blocks the calling thread until all tasks finish. + pub fn join(&self) { + (self.shared()).wait_on_done(); + } + + /// Returns a reference to the [`Queen`]. + pub fn queen(&self) -> &Q { + &self.shared().queen + } + + /// Returns the number of worker threads that have been requested, i.e., the maximum number of + /// tasks that could be processed concurrently. This may be greater than + /// [`alive_workers`](Self::alive_workers) if any of the worker threads failed to start. + pub fn max_workers(&self) -> usize { + (self.shared()).config.num_threads.get_or_default() + } + + /// Returns the number of worker threads that have been successfully started. This may be + /// fewer than [`max_workers`](Self::max_workers) if any of the worker threads failed to start. + pub fn alive_workers(&self) -> usize { + (self.shared()) + .spawn_results + .lock() + .iter() + .filter(|result| result.is_ok()) + .count() + } + + /// Returns `true` if there are any "dead" worker threads that failed to spawn. + pub fn has_dead_workers(&self) -> bool { + (self.shared()) + .spawn_results + .lock() + .iter() + .any(|result| result.is_err()) + } + + /// Attempts to respawn any dead worker threads. Returns the number of worker threads that were + /// successfully respawned. + pub fn revive_workers(&self) -> usize { + let shared = self.shared(); + shared.respawn_dead_threads(|thread_index| Self::try_spawn(thread_index, shared)) + } + + /// Returns the number of currently queued tasks and the number of tasks currently being + /// processed, as a `(queued, active)` tuple.
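+ ///
+ /// A minimal sketch (a freshly built hive with nothing submitted, so both counts are zero):
+ ///
+ /// ```
+ /// use beekeeper::bee::stock::ThunkWorker;
+ /// use beekeeper::hive::Builder;
+ ///
+ /// let hive = Builder::new()
+ ///     .num_threads(4)
+ ///     .build_with_default::<ThunkWorker<()>>();
+ /// // (queued, active) - no tasks have been submitted yet
+ /// assert_eq!(hive.num_tasks(), (0, 0));
+ /// ```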
+ pub fn num_tasks(&self) -> (u64, u64) { + (self.shared()).num_tasks() + } + + /// Returns the number of times one of this `Hive`'s worker threads has panicked. + pub fn num_panics(&self) -> usize { + (self.shared()).num_panics.get() + } + + /// Returns `true` if this `Hive` has been poisoned - i.e., its internal state has been + /// corrupted such that it is no longer able to process tasks. + /// + /// Note that, when a `Hive` is poisoned, it is still possible to call methods that extract + /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., + /// [`try_into_husk`](Self::try_into_husk)). + pub fn is_poisoned(&self) -> bool { + (self.shared()).is_poisoned() + } + + /// Returns `true` if the suspended flag is set. + pub fn is_suspended(&self) -> bool { + (self.shared()).is_suspended() + } + + /// Sets the suspended flag, which notifies worker threads that they a) MAY terminate their + /// current task early (returning an [`Outcome::Unprocessed`]), and b) MUST not accept new + /// tasks, and instead block until the suspended flag is cleared. + /// + /// Call [`resume`](Self::resume) to unset the suspended flag and continue processing tasks. + /// + /// Note: this does *not* prevent new tasks from being queued, and there is a window of time + /// (~1 second) after the suspended flag is set within which a worker thread may still accept a + /// new task. + /// + /// # Examples + /// + /// ``` + /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; + /// use beekeeper::hive::Builder; + /// use std::thread; + /// use std::time::Duration; + /// + /// # fn main() { + /// let hive = Builder::new() + /// .num_threads(4) + /// .build_with_default::>(); + /// hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); + /// thread::sleep(Duration::from_secs(1)); // Allow first set of tasks to be started. + /// // There should be 4 active tasks and 6 queued tasks. + /// hive.suspend(); + /// assert_eq!(hive.num_tasks(), (6, 4)); + /// // Wait for active tasks to complete. + /// hive.join(); + /// assert_eq!(hive.num_tasks(), (6, 0)); + /// hive.resume(); + /// // Wait for remaining tasks to complete. + /// hive.join(); + /// assert_eq!(hive.num_tasks(), (0, 0)); + /// # } + /// ``` + pub fn suspend(&self) { + (self.shared()).set_suspended(true); + } + + /// Unsets the suspended flag, allowing worker threads to continue processing queued tasks. + pub fn resume(&self) { + (self.shared()).set_suspended(false); + } + + /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over + /// the input values. + fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator { + (self.shared()) + .take_unprocessed() + .into_iter() + .map(|outcome| match outcome { + Outcome::Unprocessed { input, task_id: _ } => input, + _ => unreachable!(), + }) + } + + /// If this `Hive` is suspended, resumes this `Hive` and re-submits any unprocessed tasks for + /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that + /// were resumed. + pub fn resume_send(&self, outcome_tx: &OutcomeSender) -> Vec { + (self.shared()) + .set_suspended(false) + .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)) + .unwrap_or_default() + } + + /// If this `Hive` is suspended, resumes this `Hive` and re-submits any unprocessed tasks for + /// processing, with their results retained in the `Hive` for later retrieval. Returns a + /// [`Vec`] of task IDs that were resumed.
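+ ///
+ /// A sketch of the suspend/resume cycle (timing-dependent, so marked `no_run`; the thunk
+ /// workload is illustrative only):
+ ///
+ /// ```no_run
+ /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+ /// use beekeeper::hive::Builder;
+ /// use std::thread;
+ /// use std::time::Duration;
+ ///
+ /// let hive = Builder::new()
+ ///     .num_threads(4)
+ ///     .build_with_default::<ThunkWorker<()>>();
+ /// hive.map_store((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(1)))));
+ /// hive.suspend();
+ /// // wait for the in-flight tasks to finish
+ /// hive.join();
+ /// // resume and re-queue any tasks that were not processed while suspended
+ /// let resumed = hive.resume_store();
+ /// println!("re-queued {} tasks", resumed.len());
+ /// ```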
+ pub fn resume_store(&self) -> Vec { + (self.shared()) + .set_suspended(false) + .then(|| self.swarm_store(self.take_unprocessed_inputs())) + .unwrap_or_default() + } + + /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s. + pub fn take_stored(&self) -> HashMap> { + (self.shared()).take_outcomes() + } + + /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this + /// `Hive`, including any stored task outcomes, and all the data necessary to create a new + /// `Hive`. + /// + /// If this `Hive` has been cloned, and those clones have not been dropped, this method + /// returns `None` since it cannot take exclusive ownership of the internal shared data. + /// + /// This method first joins on the `Hive` to wait for all tasks to finish. + pub fn try_into_husk(mut self) -> Option> { + if (self.shared()).num_referrers() > 1 { + return None; + } + // take the inner value and replace it with `None` + let mut shared = self.0.take().unwrap(); + // close the global queue to prevent new tasks from being submitted + shared.global_queue.close(); + // wait for all tasks to finish + shared.wait_on_done(); + // wait for worker threads to drop, then take ownership of the shared data and convert it + // into a Husk + let mut backoff = None::; + loop { + // TODO: may want to have some timeout or other kind of limit to prevent this from + // looping forever if a worker thread somehow gets stuck, or if the `num_referrers` + // counter is corrupted + shared = match Arc::try_unwrap(shared) { + Ok(shared) => { + return Some(shared.into_husk()); + } + Err(shared) => { + backoff.get_or_insert_with(Backoff::new).spin(); + shared + } + }; + } + } +} + +use crate::hive::queue::{ChannelGlobalQueue, ChannelQueues, DefaultLocalQueues}; + +impl Default + for Hive< + W, + DefaultQueen, + ChannelGlobalQueue, + DefaultLocalQueues>, + ChannelQueues, + > +{ + fn default() -> Self { + Builder::default().build_with_default::>() + } +} + +impl Clone for Hive +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, +{ + /// Creates a shallow copy of this `Hive` containing references to its same internal state, + /// i.e., all clones of a `Hive` submit tasks to the same shared worker thread pool. 
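+ ///
+ /// For example (a sketch in the style of the other doctests; assumes `ThunkWorker` tasks),
+ /// a clone can be moved to another thread while both handles feed the same workers:
+ ///
+ /// ```
+ /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+ /// use beekeeper::hive::Builder;
+ /// use std::thread;
+ ///
+ /// let hive = Builder::new()
+ ///     .num_threads(2)
+ ///     .build_with_default::<ThunkWorker<()>>();
+ /// let clone = hive.clone();
+ /// let handle = thread::spawn(move || {
+ ///     // tasks submitted through the clone run on the same worker threads
+ ///     clone.apply_store(Thunk::of(|| ()))
+ /// });
+ /// let _task_id = handle.join().unwrap();
+ /// hive.join();
+ /// ```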
+ fn clone(&self) -> Self { + let shared = self.0.as_ref().unwrap(); + shared.referrer_is_cloning(); + Self(Some(shared.clone())) + } +} + +impl fmt::Debug for Hive +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(shared) = self.0.as_ref() { + f.debug_struct("Hive").field("shared", &shared).finish() + } else { + f.write_str("Hive {}") + } + } +} + +impl PartialEq for Hive +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, +{ + fn eq(&self, other: &Hive) -> bool { + let self_shared = self.shared(); + let other_shared = &other.shared(); + Arc::ptr_eq(self_shared, other_shared) + } +} + +impl Eq for Hive +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, +{ +} + +impl DerefOutcomes for Hive +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, +{ + #[inline] + fn outcomes_deref(&self) -> impl Deref>> { + (self.shared()).outcomes() + } + + #[inline] + fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { + (self.shared()).outcomes() + } +} + +impl Drop for Hive +where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, +{ + fn drop(&mut self) { + // if this Hive has already been turned into a Husk, its inner value will be `None` + if let Some(shared) = self.0.as_ref() { + // reduce the referrer count + let _ = shared.referrer_is_dropping(); + // if this Hive is the only one with a pointer to the shared data, poison it + // to prevent any worker threads that still have access to the shared data from + // re-spawning. + if shared.num_referrers() == 0 { + shared.poison(); + } + } + } +} + +/// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker +/// thread if it panics. +struct Sentinel +where + W: Worker, + Q: Queen, + P: QueuePair, +{ + thread_index: usize, + shared: Arc>, + active: bool, +} + +impl Sentinel +where + W: Worker, + Q: Queen, + P: QueuePair, +{ + fn new(thread_index: usize, shared: Arc>) -> Self { + Self { + thread_index, + shared, + active: true, + } + } + + /// Cancel and destroy this sentinel. + fn cancel(mut self) { + self.active = false; + } +} + +impl Drop for Sentinel +where + W: Worker, + Q: Queen, + P: QueuePair, +{ + fn drop(&mut self) { + if self.active { + // if the sentinel is active, that means the thread panicked during task execution, so + // we have to finish the task here before respawning + self.shared.finish_task(thread::panicking()); + // only respawn if the sentinel is active and the hive has not been poisoned + if !self.shared.is_poisoned() { + // can't do anything with the previous result + let _ = self + .shared + .respawn_thread(self.thread_index, |thread_index| { + Hive::try_spawn(thread_index, &self.shared) + }); + } + } + } +} + +#[cfg(not(feature = "affinity"))] +mod no_affinity { + use crate::bee::{Queen, Worker}; + use crate::hive::{GlobalQueue, Hive, LocalQueues, Shared}; + + impl, G: GlobalQueue, L: LocalQueues> Hive { + #[inline] + pub(super) fn init_thread(_: usize, _: &Shared) {} + } +} + +#[cfg(feature = "affinity")] +mod affinity { + use crate::bee::{Queen, Worker}; + use crate::hive::cores::Cores; + use crate::hive::{GlobalQueue, Hive, LocalQueues, Poisoned, QueuePair, Shared}; + + impl Hive + where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, + { + /// Tries to pin the worker thread to a specific CPU core.
+ #[inline] + pub(super) fn init_thread(thread_index: usize, shared: &Shared) { + if let Some(core) = shared.get_core_affinity(thread_index) { + core.try_pin_current(); + } + } + + /// Attempts to increase the number of worker threads by `num_threads`. + /// + /// The provided `affinity` specifies additional CPU core indices to which the worker + /// threads may be pinned - these are added to the existing pool of core indices (if any). + /// + /// Returns the number of new worker threads that were successfully started (which may be + /// fewer than `num_threads`) or a `Poisoned` error if the hive has been poisoned. + pub fn grow_with_affinity>( + &self, + num_threads: usize, + affinity: C, + ) -> Result { + (self.shared()).add_core_affinity(affinity.into()); + self.grow(num_threads) + } + + /// Sets the number of worker threads to the number of available CPU cores. An attempt is + /// made to pin each worker thread to a different CPU core. + /// + /// Returns the number of new threads spun up (if any) or a `Poisoned` error if the hive + /// has been poisoned. + pub fn use_all_cores_with_affinity(&self) -> Result { + (self.shared()).add_core_affinity(Cores::all()); + self.use_all_cores() + } + } +} + +#[cfg(feature = "batching")] +mod batching { + use crate::bee::{Queen, Worker}; + use crate::hive::{GlobalQueue, Hive, LocalQueues, QueuePair}; + + impl Hive + where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, + { + /// Returns the batch size for worker threads. + pub fn worker_batch_size(&self) -> usize { + (self.shared()).batch_size() + } + + /// Sets the batch size for worker threads. This will block the current thread until all + /// worker thread queues can be resized. + pub fn set_worker_batch_size(&self, batch_size: usize) { + (self.shared()).set_batch_size(batch_size); + } + } +} + +struct HiveTaskContext<'a, W, Q, P> +where + W: Worker, + Q: Queen, + P: QueuePair, +{ + thread_index: usize, + shared: &'a Arc>, + outcome_tx: Option<&'a OutcomeSender>, +} + +impl TaskContext for HiveTaskContext<'_, W, Q, P> +where + W: Worker, + Q: Queen, + P: QueuePair, +{ + fn should_cancel_tasks(&self) -> bool { + self.shared.is_suspended() + } + + fn submit_task(&self, input: W::Input) -> TaskId { + self.shared + .send_one_local(input, self.outcome_tx, self.thread_index) + } +} + +impl fmt::Debug for HiveTaskContext<'_, W, Q, P> +where + W: Worker, + Q: Queen, + P: QueuePair, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HiveTaskContext").finish() + } +} + +#[cfg(not(feature = "retry"))] +mod no_retry { + use super::HiveTaskContext; + use crate::bee::{Context, Queen, Worker}; + use crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, Shared, Task}; + use std::sync::Arc; + + impl Hive + where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + { + pub(super) fn execute( + task: Task, + thread_index: usize, + worker: &mut W, + shared: &Arc>, + ) { + let (task_id, input, outcome_tx) = task.into_parts(); + let task_ctx = HiveTaskContext { + thread_index, + shared, + outcome_tx: outcome_tx.as_ref(), + }; + let ctx = Context::new(task_id, Some(&task_ctx)); + let result = worker.apply(input, &ctx); + let subtask_ids = ctx.into_subtask_ids(); + let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); + shared.send_or_store_outcome(outcome, outcome_tx); + } + } +} + +#[cfg(feature = "retry")] +mod retry { + use super::HiveTaskContext; + use crate::bee::{ApplyError, Context, Queen, Worker}; + use 
crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, QueuePair, Shared, Task}; + use std::sync::Arc; + + impl Hive + where + W: Worker, + Q: Queen, + G: GlobalQueue, + L: LocalQueues, + P: QueuePair, + { + pub(super) fn execute( + task: Task, + thread_index: usize, + worker: &mut W, + shared: &Arc>, + ) { + let (task_id, input, attempt, outcome_tx) = task.into_parts(); + let task_ctx = HiveTaskContext { + thread_index, + shared, + outcome_tx: outcome_tx.as_ref(), + }; + let ctx = Context::new(task_id, attempt, Some(&task_ctx)); + // execute the task until it succeeds or we reach maximum retries - this should + // be the only place where a panic can occur + let result = worker.apply(input, &ctx); + let subtask_ids = ctx.into_subtask_ids(); + match result { + Err(ApplyError::Retryable { input, .. }) + if subtask_ids.is_none() && shared.can_retry(attempt) => + { + shared.send_retry(task_id, input, outcome_tx, attempt + 1, thread_index); + } + result => { + let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); + shared.send_or_store_outcome(outcome, outcome_tx); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::Poisoned; + use crate::bee::stock::{Caller, Thunk, ThunkWorker}; + use crate::hive::queue::ChannelQueues; + use crate::hive::{outcome_channel, Builder, Hive, Outcome, OutcomeIteratorExt}; + use std::collections::HashMap; + use std::thread; + use std::time::Duration; + + #[test] + fn test_suspend() { + let hive = Builder::new() + .num_threads(4) + .build_with_default::, ChannelQueues<_>>(); + let outcome_iter = + hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); + // Allow first set of tasks to be started. + thread::sleep(Duration::from_secs(1)); + // There should be 4 active tasks and 6 queued tasks. + hive.suspend(); + assert_eq!(hive.num_tasks(), (6, 4)); + // Wait for active tasks to complete. + hive.join(); + assert_eq!(hive.num_tasks(), (6, 0)); + hive.resume(); + // Wait for remaining tasks to complete. 
+ hive.join(); + assert_eq!(hive.num_tasks(), (0, 0)); + let outputs: Vec<_> = outcome_iter.into_outputs().collect(); + assert_eq!(outputs.len(), 10); + } + + #[test] + fn test_spawn_after_poison() { + let hive = Builder::new() + .num_threads(4) + .build_with_default::, ChannelQueues<_>>(); + assert_eq!(hive.max_workers(), 4); + assert_eq!(hive.alive_workers(), 4); + // poison hive using private method + hive.0.as_ref().unwrap().poison(); + // attempt to spawn a new task + assert!(matches!(hive.grow(1), Err(Poisoned))); + // make sure the worker count wasn't increased + assert_eq!(hive.max_workers(), 4); + assert_eq!(hive.alive_workers(), 4); + } + + #[test] + fn test_apply_after_poison() { + let hive = Builder::new() + .num_threads(4) + .build_with::<_, ChannelQueues<_>>(Caller::of(|i: usize| i * 2)); + // poison hive using private method + hive.0.as_ref().unwrap().poison(); + // submit a task, check that it comes back unprocessed + let (tx, rx) = outcome_channel(); + let sent_input = 1; + let sent_task_id = hive.apply_send(sent_input, &tx); + let outcome = rx.recv().unwrap(); + match outcome { + Outcome::Unprocessed { input, task_id } => { + assert_eq!(input, sent_input); + assert_eq!(task_id, sent_task_id); + } + _ => panic!("Expected unprocessed outcome"), + } + } + + #[test] + fn test_swarm_after_poison() { + let hive = Builder::new() + .num_threads(4) + .build_with::<_, ChannelQueues<_>>(Caller::of(|i: usize| i * 2)); + // poison hive using private method + hive.0.as_ref().unwrap().poison(); + // submit a task, check that it comes back unprocessed + let (tx, rx) = outcome_channel(); + let inputs = 0..10; + let task_ids: HashMap = hive + .swarm_send(inputs.clone(), &tx) + .into_iter() + .zip(inputs) + .collect(); + for outcome in rx.into_iter().take(10) { + match outcome { + Outcome::Unprocessed { input, task_id } => { + let expected_input = task_ids.get(&task_id); + assert!(expected_input.is_some()); + assert_eq!(input, *expected_input.unwrap()); + } + _ => panic!("Expected unprocessed outcome"), + } + } + } +} diff --git a/src/hive/queue/delay.rs b/src/hive/queue/delay.rs new file mode 100644 index 0000000..7faa6c4 --- /dev/null +++ b/src/hive/queue/delay.rs @@ -0,0 +1,164 @@ +use std::cell::UnsafeCell; +use std::cmp::Ordering; +use std::collections::BinaryHeap; +use std::time::{Duration, Instant}; + +/// A queue where each item has an associated `Instant` at which it will be available. +/// +/// This is implemented internally as a `UnsafeCell`. +/// +/// SAFETY: This data structure is designed to enable the queue to be modified by a *single thread* +/// using interior mutability. `UnsafeCell` is used for performance - this is safe so long as the +/// queue is only accessed from a single thread at a time. This data structure is *not* thread-safe. +#[derive(Debug)] +pub struct DelayQueue(UnsafeCell>>); + +impl DelayQueue { + /// Returns the number of items currently in the queue. + pub fn len(&self) -> usize { + unsafe { self.0.get().as_ref().unwrap().len() } + } + + /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available, + /// or an error with `item` if there was an error pushing the item. + pub fn push(&self, item: T, delay: Duration) -> Result { + unsafe { + match self.0.get().as_mut() { + Some(queue) => { + let delayed = Delayed::new(item, delay); + let until = delayed.until; + queue.push(delayed); + Ok(until) + } + None => Err(item), + } + } + } + + /// Returns the `Instant` at which the next item will be available. 
Returns `None` if the queue
+ /// is empty.
+ pub fn next_available(&self) -> Option<Instant> {
+ unsafe {
+ self.0
+ .get()
+ .as_ref()
+ .and_then(|queue| queue.peek().map(|head| head.until))
+ }
+ }
+
+ /// Returns the item at the head of the queue, if one exists and is available (i.e., its delay
+ /// has been exceeded), and removes it.
+ pub fn try_pop(&self) -> Option<T> {
+ unsafe {
+ if self
+ .next_available()
+ .map(|until| until <= Instant::now())
+ .unwrap_or(false)
+ {
+ self.0
+ .get()
+ .as_mut()
+ .and_then(|queue| queue.pop())
+ .map(|delayed| delayed.value)
+ } else {
+ None
+ }
+ }
+ }
+
+ /// Drains all items from the queue and returns them as an iterator.
+ pub fn drain(&mut self) -> impl Iterator<Item = T> + '_ {
+ self.0.get_mut().drain().map(|delayed| delayed.value)
+ }
+}
+
+unsafe impl<T> Sync for DelayQueue<T> {}
+
+impl<T> Default for DelayQueue<T> {
+ fn default() -> Self {
+ DelayQueue(UnsafeCell::new(BinaryHeap::new()))
+ }
+}
+
+#[derive(Debug)]
+struct Delayed<T> {
+ value: T,
+ until: Instant,
+}
+
+impl<T> Delayed<T> {
+ pub fn new(value: T, delay: Duration) -> Self {
+ Delayed {
+ value,
+ until: Instant::now() + delay,
+ }
+ }
+}
+
+/// Implements ordering for `Delayed`, so it can be used to correctly order elements in the
+/// `BinaryHeap` of the `DelayQueue`.
+///
+/// Earlier entries have higher priority (should be popped first), so they are greater than later
+/// entries.
+impl<T> Ord for Delayed<T> {
+ fn cmp(&self, other: &Delayed<T>) -> Ordering {
+ other.until.cmp(&self.until)
+ }
+}
+
+impl<T> PartialOrd for Delayed<T> {
+ fn partial_cmp(&self, other: &Delayed<T>) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl<T> PartialEq for Delayed<T> {
+ fn eq(&self, other: &Delayed<T>) -> bool {
+ self.cmp(other) == Ordering::Equal
+ }
+}
+
+impl<T> Eq for Delayed<T> {}
+
+#[cfg(test)]
+mod tests {
+ use super::DelayQueue;
+ use std::{thread, time::Duration};
+
+ #[test]
+ fn test_works() {
+ let queue = DelayQueue::default();
+
+ queue.push(1, Duration::from_secs(1)).unwrap();
+ queue.push(2, Duration::from_secs(2)).unwrap();
+ queue.push(3, Duration::from_secs(3)).unwrap();
+
+ assert_eq!(queue.len(), 3);
+ assert_eq!(queue.try_pop(), None);
+
+ thread::sleep(Duration::from_secs(1));
+ assert_eq!(queue.try_pop(), Some(1));
+ assert_eq!(queue.len(), 2);
+
+ thread::sleep(Duration::from_secs(1));
+ assert_eq!(queue.try_pop(), Some(2));
+ assert_eq!(queue.len(), 1);
+
+ thread::sleep(Duration::from_secs(1));
+ assert_eq!(queue.try_pop(), Some(3));
+ assert_eq!(queue.len(), 0);
+
+ assert_eq!(queue.try_pop(), None);
+ }
+
+ #[test]
+ fn test_into_vec() {
+ let mut queue = DelayQueue::default();
+ queue.push(1, Duration::from_secs(1)).unwrap();
+ queue.push(2, Duration::from_secs(2)).unwrap();
+ queue.push(3, Duration::from_secs(3)).unwrap();
+ let mut v: Vec<_> = queue.drain().collect();
+ v.sort();
+ assert_eq!(v, vec![1, 2, 3]);
+ }
+}
diff --git a/src/hive/queue/global.rs b/src/hive/queue/global.rs
new file mode 100644
index 0000000..87c737d
--- /dev/null
+++ b/src/hive/queue/global.rs
@@ -0,0 +1,77 @@
+use crate::atomic::{Atomic, AtomicBool};
+use crate::bee::Worker;
+use crate::hive::{GlobalPopError, GlobalQueue, Task};
+use crossbeam_channel::RecvTimeoutError;
+use std::time::Duration;
+
+/// Type alias for the input task channel sender
+type TaskSender<W> = crossbeam_channel::Sender<Task<W>>;
+/// Type alias for the input task channel receiver
+type TaskReceiver<W> = crossbeam_channel::Receiver<Task<W>>;
+
+pub struct ChannelGlobalQueue<W: Worker> {
+ tx: TaskSender<W>,
+ rx: TaskReceiver<W>,
+ closed: AtomicBool,
+}
+
+impl<W: Worker> ChannelGlobalQueue<W> {
+ /// Returns
a new `GlobalQueue` that uses the given channel sender for pushing new tasks + /// and the given channel receiver for popping tasks. + pub(super) fn new(tx: TaskSender, rx: TaskReceiver) -> Self { + Self { + tx, + rx, + closed: AtomicBool::default(), + } + } + + pub(super) fn try_pop_timeout( + &self, + timeout: Duration, + ) -> Option, GlobalPopError>> { + match self.rx.recv_timeout(timeout) { + Ok(task) => Some(Ok(task)), + Err(RecvTimeoutError::Disconnected) => Some(Err(GlobalPopError::Closed)), + Err(RecvTimeoutError::Timeout) if self.closed.get() && self.rx.is_empty() => { + Some(Err(GlobalPopError::Closed)) + } + Err(RecvTimeoutError::Timeout) => None, + } + } +} + +impl GlobalQueue for ChannelGlobalQueue { + fn try_push(&self, task: Task) -> Result<(), Task> { + if !self.closed.get() { + self.tx.try_send(task).map_err(|err| err.into_inner()) + } else { + Err(task) + } + } + + fn try_pop(&self) -> Option, GlobalPopError>> { + // time to wait in between polling the retry queue and then the task receiver + const RECV_TIMEOUT: Duration = Duration::from_secs(1); + self.try_pop_timeout(RECV_TIMEOUT) + } + + fn try_iter(&self) -> impl Iterator> + '_ { + self.rx.try_iter() + } + + fn drain(&self) -> Vec> { + self.rx.try_iter().collect() + } + + fn close(&self) { + self.closed.set(true); + } +} + +impl Default for ChannelGlobalQueue { + fn default() -> Self { + let (tx, rx) = crossbeam_channel::unbounded(); + Self::new(tx, rx) + } +} diff --git a/src/hive/queue/local.rs b/src/hive/queue/local.rs new file mode 100644 index 0000000..1993f01 --- /dev/null +++ b/src/hive/queue/local.rs @@ -0,0 +1,332 @@ +use std::marker::PhantomData; + +use crate::bee::{Queen, Worker}; +use crate::hive::{GlobalQueue, LocalQueues, QueuePair, Shared, Task}; +use parking_lot::RwLock; + +pub struct LocalQueuesImpl> { + /// thread-local queues of tasks used when the `batching` feature is enabled + #[cfg(feature = "batching")] + batch_queues: RwLock>>>, + /// thread-local queues used for tasks that are waiting to be retried after a failure + #[cfg(feature = "retry")] + retry_queues: RwLock>>>, + /// marker for the global queue type + _global: PhantomData, +} + +#[cfg(feature = "retry")] +impl> LocalQueuesImpl { + #[inline] + fn try_pop_retry(&self, thread_index: usize) -> Option> { + self.retry_queues + .read() + .get(thread_index) + .and_then(|queue| queue.try_pop()) + } +} + +#[cfg(feature = "batching")] +impl> LocalQueuesImpl { + #[inline] + fn try_push_local(&self, task: Task, thread_index: usize) -> Result<(), Task> { + self.batch_queues.read()[thread_index].push(task) + } + + #[inline] + fn try_pop_local_or_refill, P: QueuePair>( + &self, + thread_index: usize, + shared: &Shared, + ) -> Option> { + let local_queue = &self.batch_queues.read()[thread_index]; + // pop from the local queue if it has any tasks + if !local_queue.is_empty() { + return local_queue.pop(); + } + // otherwise pull at least 1 and up to `batch_size + 1` tasks from the input channel + // wait for the next task from the receiver + let first = shared.global_queue.try_pop().and_then(Result::ok); + // if we fail after trying to get one, don't keep trying to fill the queue + if first.is_some() { + let batch_size = shared.batch_size(); + // batch size 0 means batching is disabled + if batch_size > 0 { + // otherwise try to take up to `batch_size` tasks from the input channel + // and add them to the local queue, but don't block if the input channel + // is empty + for result in shared + .global_queue + .try_iter() + .take(batch_size) + .map(|task| 
local_queue.push(task)) + { + if let Err(task) = result { + // for some reason we can't push the task to the local queue; + // this should never happen, but just in case we turn it into an + // unprocessed outcome and stop iterating + shared.abandon_task(task); + break; + } + } + } + } + first + } +} + +impl> LocalQueues for LocalQueuesImpl { + fn init_for_threads, P: QueuePair>( + &self, + start_index: usize, + end_index: usize, + #[allow(unused_variables)] shared: &Shared, + ) { + #[cfg(feature = "batching")] + self.init_batch_queues_for_threads(start_index, end_index, shared); + #[cfg(feature = "retry")] + self.init_retry_queues_for_threads(start_index, end_index); + } + + #[cfg(feature = "batching")] + fn resize, P: QueuePair>( + &self, + start_index: usize, + end_index: usize, + new_size: usize, + shared: &Shared, + ) { + self.resize_batch_queues(start_index, end_index, new_size, shared); + } + + /// Creates a task from `input` and pushes it to the local queue if there is space, + /// otherwise attempts to add it to the global queue. Returns the task ID if the push + /// succeeds, otherwise returns an error with the input. + fn push, P: QueuePair>( + &self, + task: Task, + #[allow(unused_variables)] thread_index: usize, + shared: &Shared, + ) { + #[cfg(feature = "batching")] + let task = match self.try_push_local(task, thread_index) { + Ok(_) => return, + Err(task) => task, + }; + shared.push_global(task); + } + + /// Returns the next task from the local queue if there are any, otherwise attempts to + /// fetch at least 1 and up to `batch_size + 1` tasks from the input channel and puts all + /// but the first one into the local queue. + fn try_pop, P: QueuePair>( + &self, + thread_index: usize, + #[allow(unused_variables)] shared: &Shared, + ) -> Option> { + #[cfg(feature = "retry")] + if let Some(task) = self.try_pop_retry(thread_index) { + return Some(task); + } + #[cfg(feature = "batching")] + if let Some(task) = self.try_pop_local_or_refill(thread_index, shared) { + return Some(task); + } + None + } + + fn drain(&self) -> Vec> { + let mut tasks = Vec::new(); + #[cfg(feature = "batching")] + { + self.drain_batch_queues_into(&mut tasks); + } + #[cfg(feature = "retry")] + { + self.drain_retry_queues_into(&mut tasks); + } + tasks + } + + #[cfg(feature = "retry")] + fn retry, P: QueuePair>( + &self, + task: Task, + thread_index: usize, + shared: &Shared, + ) -> Option { + self.try_push_retry(task, thread_index, shared) + } +} + +impl> Default for LocalQueuesImpl { + fn default() -> Self { + Self { + #[cfg(feature = "batching")] + batch_queues: Default::default(), + #[cfg(feature = "retry")] + retry_queues: Default::default(), + _global: PhantomData, + } + } +} + +#[cfg(feature = "batching")] +mod batching { + use super::LocalQueuesImpl; + use crate::bee::{Queen, Worker}; + use crate::hive::{GlobalQueue, QueuePair, Shared, Task}; + use crossbeam_queue::ArrayQueue; + use std::collections::HashSet; + use std::time::Duration; + + impl> LocalQueuesImpl { + pub(super) fn init_batch_queues_for_threads< + Q: Queen, + P: QueuePair, + >( + &self, + start_index: usize, + end_index: usize, + shared: &Shared, + ) { + let mut batch_queues = self.batch_queues.write(); + assert_eq!(batch_queues.len(), start_index); + let queue_size = shared.batch_size().max(1); + (start_index..end_index).for_each(|_| batch_queues.push(ArrayQueue::new(queue_size))); + } + + pub(super) fn resize_batch_queues< + Q: Queen, + P: QueuePair, + >( + &self, + start_index: usize, + end_index: usize, + batch_size: usize, + 
shared: &Shared,
+ ) {
+ // keep track of which queues need to be resized
+ // TODO: this method could cause a hang if one of the worker threads is stuck - we
+ // might want to keep track of each queue's size and if we don't see it shrink
+ // within a certain amount of time, we give up on that thread and leave it with a
+ // wrong-sized queue (which should never cause a panic)
+ let mut to_resize: HashSet<usize> = (start_index..end_index).collect();
+ // iterate until we've resized them all
+ loop {
+ // scope the mutable access to batch_queues
+ {
+ let mut batch_queues = self.batch_queues.write();
+ to_resize.retain(|thread_index| {
+ let queue = if let Some(queue) = batch_queues.get_mut(*thread_index) {
+ queue
+ } else {
+ return false;
+ };
+ // leave this queue for a later pass if it still holds more tasks than
+ // the new capacity
+ if queue.len() > batch_size {
+ return true;
+ }
+ let new_queue = ArrayQueue::new(batch_size);
+ while let Some(task) = queue.pop() {
+ if let Err(task) = new_queue.push(task) {
+ // for some reason we can't push the task to the new queue;
+ // this should never happen, but just in case we turn it into
+ // an unprocessed outcome
+ shared.abandon_task(task);
+ }
+ }
+ // this is safe because the worker threads can't get readable access to the
+ // queue while this thread holds the lock
+ let old_queue = std::mem::replace(queue, new_queue);
+ assert!(old_queue.is_empty());
+ false
+ });
+ }
+ // exit once every queue has been swapped out
+ if to_resize.is_empty() {
+ break;
+ }
+ // short sleep to give worker threads the chance to pull from their queues
+ std::thread::sleep(Duration::from_millis(10));
+ }
+ }
+
+ pub(super) fn drain_batch_queues_into(&self, tasks: &mut Vec<Task<W>>) {
+ let _ = self
+ .batch_queues
+ .write()
+ .iter_mut()
+ .fold(tasks, |tasks, queue| {
+ tasks.reserve(queue.len());
+ while let Some(task) = queue.pop() {
+ tasks.push(task);
+ }
+ tasks
+ });
+ }
+ }
+}
+
+#[cfg(feature = "retry")]
+mod retry {
+ use super::LocalQueuesImpl;
+ use crate::bee::{Queen, Worker};
+ use crate::hive::queue::delay::DelayQueue;
+ use crate::hive::{GlobalQueue, QueuePair, Shared, Task};
+ use std::time::{Duration, Instant};
+
+ impl<W: Worker, G: GlobalQueue<W>> LocalQueuesImpl<W, G> {
+ /// Initializes the retry queues for the worker threads in the specified range.
+ pub(super) fn init_retry_queues_for_threads(&self, start_index: usize, end_index: usize) {
+ let mut retry_queues = self.retry_queues.write();
+ assert_eq!(retry_queues.len(), start_index);
+ (start_index..end_index).for_each(|_| retry_queues.push(DelayQueue::default()))
+ }
+
+ /// Adds a task to the retry queue with a delay based on `attempt`.
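+ ///
+ /// The delay is `retry_factor * 2^(attempt - 1)` nanoseconds, saturating at
+ /// `u64::MAX`. For example, with a retry factor of one second, successive retries
+ /// are delayed by roughly 1s, 2s, 4s, and so on. If no retry factor is configured,
+ /// the task is re-queued with no delay.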
+ pub(super) fn try_push_retry< + Q: Queen, + P: QueuePair, + >( + &self, + task: Task, + thread_index: usize, + shared: &Shared, + ) -> Option { + // compute the delay + let delay = shared + .config + .retry_factor + .get() + .map(|retry_factor| { + 2u64.checked_pow(task.attempt - 1) + .and_then(|multiplier| { + retry_factor + .checked_mul(multiplier) + .or(Some(u64::MAX)) + .map(Duration::from_nanos) + }) + .unwrap() + }) + .unwrap_or_default(); + if let Some(queue) = self.retry_queues.read().get(thread_index) { + queue.push(task, delay) + } else { + Err(task) + } + // if unable to queue the task, abandon it + .map_err(|task| shared.abandon_task(task)) + .ok() + } + + pub(super) fn drain_retry_queues_into(&self, tasks: &mut Vec>) { + let _ = self + .retry_queues + .write() + .iter_mut() + .fold(tasks, |tasks, queue| { + tasks.reserve(queue.len()); + tasks.extend(queue.drain()); + tasks + }); + } + } +} diff --git a/src/hive/queue/mod.rs b/src/hive/queue/mod.rs new file mode 100644 index 0000000..68586b2 --- /dev/null +++ b/src/hive/queue/mod.rs @@ -0,0 +1,44 @@ +#[cfg(feature = "retry")] +mod delay; +mod global; +#[cfg(any(feature = "batching", feature = "retry"))] +mod local; +#[cfg(not(any(feature = "batching", feature = "retry")))] +mod null; + +pub use global::ChannelGlobalQueue; +#[cfg(any(feature = "batching", feature = "retry"))] +pub use local::LocalQueuesImpl as DefaultLocalQueues; +#[cfg(not(any(feature = "batching", feature = "retry")))] +pub use null::LocalQueuesImpl as DefaultLocalQueues; + +use super::{GlobalQueue, LocalQueues, QueuePair}; +use crate::bee::Worker; +use std::marker::PhantomData; + +pub(crate) type ChannelQueues = + DefaultQueuePair, DefaultLocalQueues>>; + +pub(crate) struct DefaultQueuePair< + W: Worker, + G: GlobalQueue + Default, + L: LocalQueues + Default, +> { + _worker: PhantomData, + _global: PhantomData, + _local: PhantomData, +} + +impl QueuePair for DefaultQueuePair +where + W: Worker, + G: GlobalQueue + Default, + L: LocalQueues + Default, +{ + type Global = G; + type Local = L; + + fn new() -> (Self::Global, Self::Local) { + (Self::Global::default(), Self::Local::default()) + } +} diff --git a/src/hive/queue/null.rs b/src/hive/queue/null.rs new file mode 100644 index 0000000..e3be134 --- /dev/null +++ b/src/hive/queue/null.rs @@ -0,0 +1,44 @@ +use crate::bee::{Queen, Worker}; +use crate::hive::{ChannelGlobalQueue, LocalQueues, Shared, Task}; +use std::marker::PhantomData; + +pub struct LocalQueuesImpl(PhantomData W>); + +impl LocalQueues> for LocalQueuesImpl { + fn init_for_threads>( + &self, + _: usize, + _: usize, + _: &Shared, Self>, + ) { + } + + #[inline(always)] + fn push>( + &self, + task: Task, + _: usize, + shared: &Shared, Self>, + ) { + shared.push_global(task); + } + + #[inline(always)] + fn try_pop>( + &self, + _: usize, + _: &Shared, Self>, + ) -> Option> { + None + } + + fn drain(&self) -> Vec> { + Vec::new() + } +} + +impl Default for LocalQueuesImpl { + fn default() -> Self { + Self(PhantomData) + } +} diff --git a/src/hive/task.rs b/src/hive/task.rs new file mode 100644 index 0000000..b2dd42d --- /dev/null +++ b/src/hive/task.rs @@ -0,0 +1,67 @@ +use super::{Outcome, OutcomeSender, Task}; +use crate::bee::{TaskId, Worker}; + +impl Task { + /// Returns the ID of this task. + pub fn id(&self) -> TaskId { + self.id + } + + /// Consumes this `Task` and returns a `Outcome::Unprocessed` outcome with the input and ID, + /// and the outcome sender. 
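+ ///
+ /// This is how abandoned tasks (e.g., tasks submitted to a poisoned hive, or tasks
+ /// that could not be queued) are converted so that the caller can recover their
+ /// inputs.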
+ pub fn into_unprocessed(self) -> (Outcome, Option>) { + let outcome = Outcome::Unprocessed { + input: self.input, + task_id: self.id, + }; + (outcome, self.outcome_tx) + } +} + +#[cfg(not(feature = "retry"))] +impl Task { + /// Creates a new `Task`. + pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { + Task { + id, + input, + outcome_tx, + } + } + + pub fn into_parts(self) -> (TaskId, W::Input, Option>) { + (self.id, self.input, self.outcome_tx) + } +} + +#[cfg(feature = "retry")] +impl Task { + /// Creates a new `Task`. + pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { + Task { + id, + input, + outcome_tx, + attempt: 0, + } + } + + /// Creates a new `Task`. + pub fn with_attempt( + id: TaskId, + input: W::Input, + outcome_tx: Option>, + attempt: u32, + ) -> Self { + Task { + id, + input, + outcome_tx, + attempt, + } + } + + pub fn into_parts(self) -> (TaskId, W::Input, u32, Option>) { + (self.id, self.input, self.attempt, self.outcome_tx) + } +} From 234c2ea0cca50b7306be7368d76c2511947f5961 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sun, 16 Feb 2025 15:19:47 -0800 Subject: [PATCH 09/67] reorganize hive module --- CHANGELOG.md | 17 +- Cargo.toml | 5 +- benches/perf.rs | 31 +- src/bee/context.rs | 22 + src/bee/mod.rs | 14 +- src/bee/queen.rs | 15 +- src/hive/builder.rs | 623 ---- src/hive/builder/bee.rs | 127 + src/hive/builder/channel.rs | 71 + src/hive/builder/full.rs | 48 + src/hive/builder/mod.rs | 50 + src/hive/builder/open.rs | 277 ++ src/hive/{core.rs => hive.rs} | 315 +- src/hive/husk.rs | 94 +- src/hive/inner/builder.rs | 331 ++ src/hive/{ => inner}/config.rs | 54 +- src/hive/{ => inner}/counter.rs | 0 src/hive/{ => inner}/gate.rs | 0 src/hive/inner/mod.rs | 107 + .../local.rs => inner/queue/channel.rs} | 286 +- src/hive/{ => inner}/queue/delay.rs | 0 src/hive/inner/queue/mod.rs | 96 + src/hive/inner/queue/workstealing.rs | 36 + src/hive/{ => inner}/shared.rs | 135 +- src/hive/{ => inner}/task.rs | 3 +- src/hive/mod.rs | 3263 ++++++++--------- src/hive/outcome/mod.rs | 84 +- src/hive/outcome/outcome.rs | 452 --- src/hive/outcome/queue.rs | 3 +- src/hive/outcome/store.rs | 55 +- src/hive/queue/global.rs | 77 - src/hive/queue/mod.rs | 44 - src/hive/queue/null.rs | 44 - src/hive/scoped/hive.rs | 152 - src/hive/scoped/mod.rs | 0 src/hive/workstealing.rs | 17 - src/util.rs | 21 +- 37 files changed, 3326 insertions(+), 3643 deletions(-) delete mode 100644 src/hive/builder.rs create mode 100644 src/hive/builder/bee.rs create mode 100644 src/hive/builder/channel.rs create mode 100644 src/hive/builder/full.rs create mode 100644 src/hive/builder/mod.rs create mode 100644 src/hive/builder/open.rs rename src/hive/{core.rs => hive.rs} (81%) create mode 100644 src/hive/inner/builder.rs rename src/hive/{ => inner}/config.rs (84%) rename src/hive/{ => inner}/counter.rs (100%) rename src/hive/{ => inner}/gate.rs (100%) create mode 100644 src/hive/inner/mod.rs rename src/hive/{queue/local.rs => inner/queue/channel.rs} (67%) rename src/hive/{ => inner}/queue/delay.rs (100%) create mode 100644 src/hive/inner/queue/mod.rs create mode 100644 src/hive/inner/queue/workstealing.rs rename src/hive/{ => inner}/shared.rs (88%) rename src/hive/{ => inner}/task.rs (96%) delete mode 100644 src/hive/outcome/outcome.rs delete mode 100644 src/hive/queue/global.rs delete mode 100644 src/hive/queue/mod.rs delete mode 100644 src/hive/queue/null.rs delete mode 100644 src/hive/scoped/hive.rs delete mode 100644 src/hive/scoped/mod.rs delete mode 100644 
src/hive/workstealing.rs

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4fc475b..9162526 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,17 +5,26 @@
The general theme of this release is performance improvement by eliminating thread contention due to unnecessary locking of shared state. This required making some breaking changes to the API.
* **Breaking**
+ * `beekeeper::hive::Hive` type signature has changed
+   * Removed the `W: Worker` parameter as it is redundant (can be obtained from `Q::Kind`)
+   * Added `T: TaskQueues` to specify the `TaskQueues` implementation
+ * The `Builder` interface has been rewritten to enable maximum flexibility.
+   * `Builder` is now a trait that must be in scope.
+   * `ChannelBuilder` implements the previous builder functionality.
+   * `OpenBuilder` has no type parameters and can be specialized to create a `Hive` with any combination of `Queen` and `TaskQueues`.
+   * `BeeBuilder` and `FullBuilder` are intermediate types that generally should not be instantiated directly.
* `beekeeper::bee::Queen::create` now takes `&self` rather than `&mut self`. There is a new type, `beekeeper::bee::QueenMut`, with a `create(&mut self)` method, and needs to be wrapped in a `beekeeper::bee::QueenCell` to implement the `Queen` trait. This enables the `Hive` to create new workers without locking in the case of a `Queen` that does not need mutable state.
* `beekeeper::bee::Context` now takes a generic parameter that must be the input type of the `Worker`.
* Features
- * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks.
+ * Added the `TaskQueues` trait, which enables `Hive` to be specialized for different implementations of global (i.e., sending tasks from the `Hive` to worker threads) and local (i.e., worker thread-specific) queues.
+   * `ChannelTaskQueues` implements the existing behavior, using a channel for sending tasks.
+   * `WorkstealingTaskQueues` has been added to implement the workstealing pattern, based on `crossbeam-deque`.
+ * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. This feature is only used by `ChannelTaskQueues`.
* Added the `Context::submit` method, which enables tasks to submit new tasks to the `Hive`.
* Other
- * `beekeeper::hive::Hive` now has additional generic parameters for the global and local queue types. These default to `beekeeper::hive::ChannelGlobalQueue` and `beekeeper::hive::DefaultLocalQueues`, which provide the same behavior as before.
- * `beekeeper::hive::ChannelHive` is the existing `Hive` implementation.
* Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread contention.
* Switched to storing `Outcome`s in the hive using a data structure that does not require locking when inserting, which should reduce thread contention when using `*_store` operations.
- * Switched to using `crossbeam_channel` for the `Hive`'s task input channel.
+ * Switched to using `crossbeam_channel` for the task input channel in `ChannelTaskQueues`.
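+
+ As a quick migration sketch (based on the updated benchmark in this patch set; `EchoWorker` is one of the stock workers), the old `Builder::new().num_threads(n).build_with_default()` call becomes:
+
+ ```rust
+ use beekeeper::bee::stock::EchoWorker;
+ use beekeeper::hive::{Builder, ChannelBuilder};
+
+ // the `Builder` trait must be in scope to call config methods like `num_threads`
+ let hive = ChannelBuilder::empty()
+     .num_threads(4)
+     .with_worker_default::<EchoWorker<usize>>()
+     .build();
+ ```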
## 0.2.1 diff --git a/Cargo.toml b/Cargo.toml index 45bcdb8..95c2478 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ license = "MIT OR Apache-2.0" [dependencies] crossbeam-channel = "0.5.13" crossbeam-deque = "0.8.6" +crossbeam-queue = "0.3.12" crossbeam-utils = "0.8.20" num = "0.4.3" num_cpus = "1.16.0" @@ -19,8 +20,6 @@ paste = "1.0.15" thiserror = "1.0.63" # required with the `affinity` feature core_affinity = { version = "0.8.1", optional = true } -# required with the `batching` feature -crossbeam-queue = { version = "0.3.12", optional = true } # required with alternate outcome channel implementations that can be enabled with features flume = { version = "0.11.1", optional = true } loole = { version = "0.4.0", optional = true } @@ -39,7 +38,7 @@ harness = false [features] default = ["affinity", "batching", "retry"] affinity = ["dep:core_affinity"] -batching = ["dep:crossbeam-queue"] +batching = [] retry = [] crossbeam = [] flume = ["dep:flume"] diff --git a/benches/perf.rs b/benches/perf.rs index 6314462..8bce664 100644 --- a/benches/perf.rs +++ b/benches/perf.rs @@ -1,5 +1,5 @@ use beekeeper::bee::stock::EchoWorker; -use beekeeper::hive::{outcome_channel, Builder}; +use beekeeper::hive::{outcome_channel, Builder, ChannelBuilder}; use divan::{bench, black_box_drop, AllocProfiler, Bencher}; use itertools::iproduct; @@ -9,20 +9,21 @@ static ALLOC: AllocProfiler = AllocProfiler::system(); const THREADS: &[usize] = &[1, 4, 8, 16]; const TASKS: &[usize] = &[1, 100, 10_000, 1_000_000]; -// #[bench(args = iproduct!(THREADS, TASKS))] -// fn bench_apply_short_task(bencher: Bencher, (num_threads, num_tasks): (&usize, &usize)) { -// let hive = Builder::new() -// .num_threads(*num_threads) -// .build_with_default::>(); -// bencher.bench_local(|| { -// let (tx, rx) = outcome_channel(); -// for i in 0..*num_tasks { -// hive.apply_send(i, &tx); -// } -// hive.join(); -// rx.into_iter().take(*num_tasks).for_each(black_box_drop); -// }) -// } +#[bench(args = iproduct!(THREADS, TASKS))] +fn bench_apply_short_task(bencher: Bencher, (num_threads, num_tasks): (&usize, &usize)) { + let hive = ChannelBuilder::empty() + .num_threads(*num_threads) + .with_worker_default::>() + .build(); + bencher.bench_local(|| { + let (tx, rx) = outcome_channel(); + for i in 0..*num_tasks { + hive.apply_send(i, &tx); + } + hive.join(); + rx.into_iter().take(*num_tasks).for_each(black_box_drop); + }) +} fn main() { divan::main(); diff --git a/src/bee/context.rs b/src/bee/context.rs index 09abb48..e804bc0 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -124,3 +124,25 @@ impl<'a, I> Context<'a, I> { self.attempt } } + +#[cfg(test)] +pub mod mock { + use super::{TaskContext, TaskId}; + use std::cell::RefCell; + + #[derive(Debug, Default)] + pub struct MockTaskContext(RefCell); + + impl TaskContext for MockTaskContext { + fn should_cancel_tasks(&self) -> bool { + false + } + + fn submit_task(&self, _: I) -> super::TaskId { + let mut task_id = self.0.borrow_mut(); + let cur_id = *task_id; + *task_id += 1; + cur_id + } + } +} diff --git a/src/bee/mod.rs b/src/bee/mod.rs index 75998b1..7b1027d 100644 --- a/src/bee/mod.rs +++ b/src/bee/mod.rs @@ -115,14 +115,16 @@ mod queen; pub mod stock; mod worker; -pub use context::{Context, TaskContext, TaskId}; -pub use error::{ApplyError, ApplyRefError}; -pub use queen::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut}; -pub use worker::{RefWorker, RefWorkerResult, Worker, WorkerError, WorkerResult}; +#[cfg(test)] +pub use self::context::mock::MockTaskContext; +pub 
use self::context::{Context, TaskContext, TaskId}; +pub use self::error::{ApplyError, ApplyRefError}; +pub use self::queen::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut}; +pub use self::worker::{RefWorker, RefWorkerResult, Worker, WorkerError, WorkerResult}; pub mod prelude { pub use super::{ - ApplyError, ApplyRefError, Context, Queen, QueenMut, RefWorker, RefWorkerResult, Worker, - WorkerError, WorkerResult, + ApplyError, ApplyRefError, Context, Queen, QueenCell, QueenMut, RefWorker, RefWorkerResult, + Worker, WorkerError, WorkerResult, }; } diff --git a/src/bee/queen.rs b/src/bee/queen.rs index 64ba7b1..7642467 100644 --- a/src/bee/queen.rs +++ b/src/bee/queen.rs @@ -2,6 +2,7 @@ use super::Worker; use parking_lot::RwLock; use std::marker::PhantomData; +use std::ops::Deref; /// A trait for factories that create `Worker`s. pub trait Queen: Send + Sync + 'static { @@ -21,7 +22,9 @@ pub trait QueenMut: Send + Sync + 'static { fn create(&mut self) -> Self::Kind; } -/// A wrapper for a `MutQueen` that implements `Queen` using an `RwLock` internally. +/// A wrapper for a `MutQueen` that implements `Queen`. +/// +/// Interior mutability is enabled using an `RwLock`. pub struct QueenCell(RwLock); impl QueenCell { @@ -29,6 +32,10 @@ impl QueenCell { Self(RwLock::new(mut_queen)) } + pub fn get(&self) -> impl Deref + '_ { + self.0.read() + } + pub fn into_inner(self) -> Q { self.0.into_inner() } @@ -48,6 +55,12 @@ impl Default for QueenCell { } } +impl From for QueenCell { + fn from(queen: Q) -> Self { + Self::new(queen) + } +} + /// A `Queen` that can create a `Worker` type that implements `Default`. /// /// Note that, for the implementation to be generic, `W` also needs to be `Send` and `Sync`. If you diff --git a/src/hive/builder.rs b/src/hive/builder.rs deleted file mode 100644 index 6fdc4e5..0000000 --- a/src/hive/builder.rs +++ /dev/null @@ -1,623 +0,0 @@ -use super::{Config, Hive, QueuePair}; -use crate::bee::{CloneQueen, DefaultQueen, Queen, Worker}; - -/// A `Builder` for a [`Hive`](crate::hive::Hive). -/// -/// Calling [`Builder::new()`] creates an unconfigured `Builder`, while calling -/// [`Builder::default()`] creates a `Builder` with fields pre-set to the global default values. -/// Global defaults can be changed using the -/// [`beekeeper::hive::set_*_default`](crate::hive#functions) functions. -/// -/// The configuration options available: -/// * [`Builder::num_threads`]: number of worker threads that will be spawned by the built `Hive`. -/// * [`Builder::with_default_num_threads`] will set `num_threads` to the global default value. -/// * [`Builder::with_thread_per_core`] will set `num_threads` to the number of available CPU -/// cores. -/// * [`Builder::thread_name`]: thread name for each of the threads spawned by the built `Hive`. By -/// default, threads are unnamed. -/// * [`Builder::thread_stack_size`]: stack size (in bytes) for each of the threads spawned by the -/// built `Hive`. See the -/// [`std::thread`](https://doc.rust-lang.org/stable/std/thread/index.html#stack-size) -/// documentation for details on the default stack size. -/// -/// The following configuration options are available when the `retry` feature is enabled: -/// * [`Builder::max_retries`]: maximum number of times a `Worker` will retry an -/// [`ApplyError::Retryable`](crate::bee::ApplyError#Retryable) before giving up. -/// * [`Builder::retry_factor`]: [`Duration`](std::time::Duration) factor for exponential backoff -/// when retrying an `ApplyError::Retryable` error. 
-/// * [`Builder::with_default_retries`] sets the retry options to the global defaults, while -/// [`Builder::with_no_retries`] disabled retrying. -/// -/// The following configuration options are available when the `affinity` feature is enabled: -/// * [`Builder::core_affinity`]: List of CPU core indices to which the threads should be pinned. -/// * [`Builder::with_default_core_affinity`] will set the list to all CPU core indices, though -/// only the first `num_threads` indices will be used. -/// -/// To create the [`Hive`], call one of the `build*` methods: -/// * [`Builder::build`] requires a [`Queen`] instance. -/// * [`Builder::build_default`] requires a [`Queen`] type that implements [`Default`]. -/// * [`Builder::build_with`] requires a [`Worker`] instance that implements [`Clone`]. -/// * [`Builder::build_with_default`] requires a [`Worker`] type that implements [`Default`]. -/// -/// # Examples -/// -/// Build a [`Hive`] that uses a maximum of eight threads simultaneously and each thread has -/// a 8 MB stack size: -/// -/// ``` -/// type MyWorker = beekeeper::bee::stock::ThunkWorker<()>; -/// -/// let hive = beekeeper::hive::Builder::new() -/// .num_threads(8) -/// .thread_stack_size(8_000_000) -/// .build_with_default::(); -/// ``` -#[derive(Clone)] -pub struct Builder(Config); - -impl Builder { - /// Returns a new `Builder` with no options configured. - pub fn new() -> Self { - Self(Config::empty()) - } - - /// Sets the maximum number of worker threads that will be alive at any given moment in the - /// built [`Hive`]. If not specified, the built `Hive` will not be initialized with worker - /// threads until [`Hive::grow`] is called. - /// - /// # Examples - /// - /// No more than eight threads will be alive simultaneously for this hive: - /// - /// ``` - /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; - /// - /// # fn main() { - /// let hive = Builder::new() - /// .num_threads(8) - /// .build_with_default::>(); - /// - /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { - /// println!("Hello from a worker thread!") - /// })); - /// } - /// # } - /// ``` - pub fn num_threads(mut self, num: usize) -> Self { - let _ = self.0.num_threads.set(Some(num)); - self - } - - /// Sets the number of worker threads to the global default value. - pub fn with_default_num_threads(mut self) -> Self { - let _ = self - .0 - .num_threads - .set(super::config::DEFAULTS.lock().num_threads.get()); - self - } - - /// Specifies that the built [`Hive`] will use all available CPU cores for worker threads. - /// - /// # Examples - /// - /// All available threads will be alive simultaneously for this hive: - /// - /// ``` - /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; - /// - /// # fn main() { - /// let hive = Builder::new() - /// .with_thread_per_core() - /// .build_with_default::>(); - /// - /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { - /// println!("Hello from a worker thread!") - /// })); - /// } - /// # } - /// ``` - pub fn with_thread_per_core(mut self) -> Self { - let _ = self.0.num_threads.set(Some(num_cpus::get())); - self - } - - /// Sets the thread name for each of the threads spawned by the built [`Hive`]. If not - /// specified, threads spawned by the thread pool will be unnamed. 
- /// - /// # Examples - /// - /// Each thread spawned by this hive will have the name `"foo"`: - /// - /// ``` - /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; - /// use std::thread; - /// - /// # fn main() { - /// let hive = Builder::default() - /// .thread_name("foo") - /// .build_with_default::>(); - /// - /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { - /// assert_eq!(thread::current().name(), Some("foo")); - /// })); - /// } - /// # hive.join(); - /// # } - /// ``` - pub fn thread_name>(mut self, name: T) -> Self { - let _ = self.0.thread_name.set(Some(name.into())); - self - } - - /// Sets the stack size (in bytes) for each of the threads spawned by the built [`Hive`]. - /// If not specified, threads spawned by the hive will have a stack size [as specified in - /// the `std::thread` documentation][thread]. - /// - /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size - /// - /// # Examples - /// - /// Each thread spawned by this hive will have a 4 MB stack: - /// - /// ``` - /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; - /// - /// # fn main() { - /// let hive = Builder::default() - /// .thread_stack_size(4_000_000) - /// .build_with_default::>(); - /// - /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { - /// println!("This thread has a 4 MB stack size!"); - /// })); - /// } - /// # hive.join(); - /// # } - /// ``` - pub fn thread_stack_size(mut self, size: usize) -> Self { - let _ = self.0.thread_stack_size.set(Some(size)); - self - } - - /// Consumes this `Builder` and returns a new [`Hive`] using the given [`Queen`] to create - /// [`Worker`]s. - /// - /// Returns an error if there was an error spawning the worker threads. 
- /// - /// # Examples - /// - /// ``` - /// # use beekeeper::hive::{Builder, Hive}; - /// # use beekeeper::bee::{Context, Queen, Worker, WorkerResult}; - /// - /// #[derive(Debug)] - /// struct CounterWorker { - /// index: usize, - /// input_count: usize, - /// input_sum: usize, - /// } - /// - /// impl CounterWorker { - /// fn new(index: usize) -> Self { - /// Self { - /// index, - /// input_count: 0, - /// input_sum: 0, - /// } - /// } - /// } - /// - /// impl Worker for CounterWorker { - /// type Input = usize; - /// type Output = String; - /// type Error = (); - /// - /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { - /// self.input_count += 1; - /// self.input_sum += input; - /// let s = format!( - /// "CounterWorker {}: Input {}, Count {}, Sum {}", - /// self.index, input, self.input_count, self.input_sum - /// ); - /// Ok(s) - /// } - /// } - /// - /// #[derive(Debug, Default)] - /// struct CounterQueen { - /// num_workers: usize - /// } - /// - /// impl Queen for CounterQueen { - /// type Kind = CounterWorker; - /// - /// fn create(&mut self) -> Self::Kind { - /// self.num_workers += 1; - /// CounterWorker::new(self.num_workers) - /// } - /// } - /// - /// # fn main() { - /// let hive = Builder::new() - /// .num_threads(8) - /// .thread_stack_size(4_000_000) - /// .build(CounterQueen::default()); - /// - /// for i in 0..100 { - /// hive.apply_store(i); - /// } - /// let husk = hive.try_into_husk().unwrap(); - /// assert_eq!(husk.queen().num_workers, 8); - /// # } - /// ``` - pub fn build>( - self, - queen: Q, - ) -> Hive { - Hive::new(self.0, queen) - } - - /// Consumes this `Builder` and returns a new [`Hive`] using a [`Queen`] created with - /// [`Q::default()`](std::default::Default) to create [`Worker`]s. - /// - /// Returns an error if there was an error spawning the worker threads. - pub fn build_default>( - self, - ) -> Hive { - Hive::new(self.0, Q::default()) - } - - /// Consumes this `Builder` and returns a new [`Hive`] with [`Worker`]s created by cloning - /// `worker`. - /// - /// Returns an error if there was an error spawning the worker threads. 
- /// - /// # Examples - /// - /// ``` - /// # use beekeeper::hive::{Builder, OutcomeIteratorExt}; - /// # use beekeeper::bee::{Context, Worker, WorkerResult}; - /// - /// #[derive(Debug, Clone)] - /// struct MathWorker(isize); - /// - /// impl MathWorker { - /// fn new(left_operand: isize) -> Self { - /// assert!(left_operand != 0); - /// Self(left_operand) - /// } - /// } - /// - /// impl Worker for MathWorker { - /// type Input = (isize, u8); - /// type Output = isize; - /// type Error = (); - /// - /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { - /// let (operand, operator) = input; - /// let value = match operator % 4 { - /// 0 => operand + self.0, - /// 1 => operand - self.0, - /// 2 => operand * self.0, - /// 3 => operand / self.0, - /// _ => unreachable!(), - /// }; - /// Ok(value) - /// } - /// } - /// - /// # fn main() { - /// let hive = Builder::new() - /// .num_threads(8) - /// .thread_stack_size(4_000_000) - /// .build_with(MathWorker(5isize)); - /// - /// let sum: isize = hive - /// .map((0..100).zip((0..4).cycle())) - /// .into_outputs() - /// .sum(); - /// assert_eq!(sum, 8920); - /// # } - /// ``` - pub fn build_with>( - self, - worker: W, - ) -> Hive, P::Global, P::Local, P> - where - W: Worker + Send + Sync + Clone, - { - Hive::new(self.0, CloneQueen::new(worker)) - } - - /// Consumes this `Builder` and returns a new [`Hive`] with [`Worker`]s created using - /// [`W::default()`](std::default::Default). - /// - /// Returns a [`SpawnError`](crate::hive::SpawnError) if there was an error spawning the - /// worker threads. - /// - /// # Examples - /// - /// ``` - /// # use beekeeper::hive::{Builder, OutcomeIteratorExt}; - /// # use beekeeper::bee::{Context, Worker, WorkerResult}; - /// # use std::num::NonZeroIsize; - /// - /// #[derive(Debug, Default)] - /// struct MathWorker(isize); // value is always `0` - /// - /// impl Worker for MathWorker { - /// type Input = (NonZeroIsize, u8); - /// type Output = isize; - /// type Error = (); - /// - /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { - /// let (operand, operator) = input; - /// let result = match operator % 4 { - /// 0 => self.0 + operand.get(), - /// 1 => self.0 - operand.get(), - /// 2 => self.0 * operand.get(), - /// 3 => self.0 / operand.get(), - /// _ => unreachable!(), - /// }; - /// Ok(result) - /// } - /// } - /// - /// # fn main() { - /// let hive = Builder::new() - /// .num_threads(8) - /// .thread_stack_size(4_000_000) - /// .build_with_default::(); - /// - /// let sum: isize = hive - /// .map((1..=100).map(|i| NonZeroIsize::new(i).unwrap()).zip((0..4).cycle())) - /// .into_outputs() - /// .sum(); - /// assert_eq!(sum, -25); - /// # } - /// ``` - pub fn build_with_default>( - self, - ) -> Hive, P::Global, P::Local, P> - where - W: Worker + Send + Sync + Default, - { - Hive::new(self.0, DefaultQueen::default()) - } -} - -impl Default for Builder { - /// Creates a new `Builder` with default configuration options: - /// * `num_threads = config::DEFAULT_NUM_THREADS` - /// - /// The following default configuration options are used when the `retry` feature is enabled: - /// * `max_retries = config::retry::DEFAULT_MAX_RETRIES` - /// * `retry_factor = config::retry::DEFAULT_RETRY_FACTOR_SECS` - fn default() -> Self { - Builder(Config::with_defaults()) - } -} - -impl From for Builder { - fn from(value: Config) -> Self { - Self(value) - } -} - -#[cfg(feature = "affinity")] -mod affinity { - use super::Builder; - use crate::hive::cores::Cores; - - impl Builder { - 
/// Sets set list of CPU core indices to which threads in the `Hive` should be pinned. - /// - /// Core indices are integers in the range `0..N`, where `N` is the number of available CPU - /// cores as reported by [`num_cpus::get()`]. The mapping between core indices and core IDs - /// is platform-specific. All CPU cores on a given system should be equivalent, and thus it - /// does not matter which cores are pinned so long as a core is not pinned to multiple - /// threads. - /// - /// Excess core indices (i.e., if `affinity.len() > num_threads`) are ignored. If - /// `affinity.len() < num_threads` then the excess threads will not be pinned. - /// - /// # Examples - /// - /// Each thread spawned by this hive will be pinned to a core: - /// - /// ``` - /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; - /// - /// # fn main() { - /// let hive = Builder::new() - /// .num_threads(4) - /// .core_affinity(0..4) - /// .build_with_default::>(); - /// - /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { - /// println!("This thread is pinned!"); - /// })); - /// } - /// # hive.join(); - /// # } - /// ``` - pub fn core_affinity>(mut self, affinity: C) -> Self { - let _ = self.0.affinity.set(Some(affinity.into())); - self - } - - /// Specifies that worker threads should be pinned to all available CPU cores. If - /// `num_threads` is greater than the available number of CPU cores, then some threads - /// might not be pinned. - pub fn with_default_core_affinity(mut self) -> Self { - let _ = self.0.affinity.set(Some(Cores::all())); - self - } - } - - #[cfg(test)] - mod tests { - use crate::hive::cores::Cores; - use crate::hive::Builder; - - #[test] - fn test_with_affinity() { - let mut builder = Builder::new(); - builder = builder.with_default_core_affinity(); - assert_eq!(builder.0.affinity.get(), Some(Cores::all())); - } - } -} - -#[cfg(feature = "batching")] -mod batching { - use super::Builder; - - impl Builder { - /// Sets the worker thread batch size. If `batch_size` is `0`, batching is disabled, but - /// note that the performance may be worse than with the `batching` feature disabled. - pub fn batch_size(mut self, batch_size: usize) -> Self { - if batch_size == 0 { - self.0.batch_size.set(None); - } else { - self.0.batch_size.set(Some(batch_size)); - } - self - } - - /// Sets the worker thread batch size to the global default value. - pub fn with_default_batch_size(mut self) -> Self { - let _ = self - .0 - .batch_size - .set(crate::hive::config::DEFAULTS.lock().batch_size.get()); - self - } - } -} - -#[cfg(feature = "retry")] -mod retry { - use super::Builder; - use std::time::Duration; - - impl Builder { - /// Sets the maximum number of times to retry a - /// [`ApplyError::Retryable`](crate::bee::ApplyError::Retryable) error. A worker - /// thread will retry a task until it either returns - /// [`ApplyError::Fatal`](crate::bee::ApplyError::Fatal) or the maximum number of retries is - /// reached. Each time a task is retried, the worker thread will first sleep for - /// `retry_factor * (2 ** (attempt - 1))` before attempting the task again. If not - /// specified, tasks are retried a default number of times. If set to `0`, tasks will be - /// retried immediately without delay. 
- /// - /// # Examples - /// - /// ``` - /// use beekeeper::bee::{ApplyError, Context}; - /// use beekeeper::bee::stock::RetryCaller; - /// use beekeeper::hive::{Builder, Hive}; - /// use std::time; - /// - /// fn sometimes_fail( - /// i: usize, - /// _: &Context - /// ) -> Result> { - /// match i % 3 { - /// 0 => Ok("Success".into()), - /// 1 => Err(ApplyError::Retryable { input: i, error: "Retryable".into() }), - /// 2 => Err(ApplyError::Fatal { input: Some(i), error: "Fatal".into() }), - /// _ => unreachable!(), - /// } - /// } - /// - /// # fn main() { - /// let hive = Builder::default() - /// .max_retries(3) - /// .build_with(RetryCaller::of(sometimes_fail)); - /// - /// for i in 0..10 { - /// hive.apply_store(i); - /// } - /// # hive.join(); - /// # } - /// ``` - pub fn max_retries(mut self, limit: u32) -> Self { - let _ = if limit == 0 { - self.0.max_retries.set(None) - } else { - self.0.max_retries.set(Some(limit)) - }; - self - } - - /// Sets the exponential back-off factor for retrying tasks. Each time a task is retried, - /// the thread will first sleep for `retry_factor * (2 ** (attempt - 1))`. If not - /// specififed, a default retry factor is used. Set to - /// [`Duration::ZERO`](std::time::Duration::ZERO) to disableexponential backoff. - /// - /// # Examples - /// - /// ``` - /// use beekeeper::bee::{ApplyError, Context}; - /// use beekeeper::bee::stock::RetryCaller; - /// use beekeeper::hive::{Builder, Hive}; - /// use std::time; - /// - /// fn echo_time(i: usize, ctx: &Context) -> Result> { - /// let attempt = ctx.attempt(); - /// if attempt == 3 { - /// Ok("Success".into()) - /// } else { - /// // the delay between each message should be exponential - /// println!("Task {} attempt {}: {:?}", i, attempt, time::SystemTime::now()); - /// Err(ApplyError::Retryable { input: i, error: "Retryable".into() }) - /// } - /// } - /// - /// # fn main() { - /// let hive = Builder::default() - /// .max_retries(3) - /// .retry_factor(time::Duration::from_secs(1)) - /// .build_with(RetryCaller::of(echo_time)); - /// - /// for i in 0..10 { - /// hive.apply_store(i); - /// } - /// # hive.join(); - /// # } - /// ``` - pub fn retry_factor(mut self, duration: Duration) -> Self { - let _ = if duration == Duration::ZERO { - self.0.retry_factor.set(None) - } else { - self.0.set_retry_factor_from(duration) - }; - self - } - - /// Sets retry parameters to their default values. - pub fn with_default_retries(mut self) -> Self { - let defaults = crate::hive::config::DEFAULTS.lock(); - let _ = self.0.max_retries.set(defaults.max_retries.get()); - let _ = self.0.retry_factor.set(defaults.retry_factor.get()); - self - } - - /// Disables retrying tasks. - pub fn with_no_retries(self) -> Self { - self.max_retries(0).retry_factor(Duration::ZERO) - } - } -} diff --git a/src/hive/builder/bee.rs b/src/hive/builder/bee.rs new file mode 100644 index 0000000..466acf4 --- /dev/null +++ b/src/hive/builder/bee.rs @@ -0,0 +1,127 @@ +use super::{BuilderConfig, FullBuilder, Token}; +use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; +use crate::hive::{ChannelTaskQueues, Config, TaskQueues}; + +/// A Builder for creating `Hive` instances for specific [`Worker`] and [`TaskQueues`] types. +#[derive(Clone, Default)] +pub struct BeeBuilder { + config: Config, + queen: Q, +} + +impl BeeBuilder { + /// Creates a new `BeeBuilder` with the given queen and no options configured. 
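+ ///
+ /// Use [`BeeBuilder::preset`] instead to start from the global preset values.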
+    pub fn empty>(queen: Q) -> Self {
+        Self {
+            config: Config::empty(),
+            queen: queen.into(),
+        }
+    }
+
+    /// Creates a new `BeeBuilder` with the given `queen` and options configured with global
+    /// preset values.
+    pub fn preset>(queen: Q) -> Self {
+        Self {
+            config: Config::default(),
+            queen: queen.into(),
+        }
+    }
+
+    /// Creates a new `BeeBuilder` from an existing `config` and a `queen`.
+    pub(super) fn from(config: Config, queen: Q) -> Self {
+        Self {
+            config,
+            queen: queen.into(),
+        }
+    }
+
+    /// Creates a new `FullBuilder` with the current configuration and queen, and the specified
+    /// `TaskQueues` type.
+    pub fn with_queues>(self) -> FullBuilder {
+        FullBuilder::from(self.config, self.queen)
+    }
+
+    /// Creates a new `FullBuilder` with the current configuration and queen, and channel-based
+    /// task queues.
+    pub fn with_channel_queues(self) -> FullBuilder> {
+        FullBuilder::from(self.config, self.queen)
+    }
+}
+
+impl BeeBuilder {
+    /// Creates a new `BeeBuilder` with a queen created with
+    /// [`Q::default()`](std::default::Default) and no options configured.
+    pub fn empty_with_queen_default() -> Self {
+        Self {
+            config: Config::empty(),
+            queen: Q::default(),
+        }
+    }
+
+    /// Creates a new `BeeBuilder` with a queen created with
+    /// [`Q::default()`](std::default::Default) and options configured with global defaults.
+    pub fn preset_with_queen_default() -> Self {
+        Self {
+            config: Config::default(),
+            queen: Q::default(),
+        }
+    }
+}
+
+impl BeeBuilder> {
+    /// Creates a new `BeeBuilder` with a queen created with
+    /// [`Q::default()`](std::default::Default) and no options configured.
+    pub fn empty_with_queen_mut_default() -> Self {
+        Self {
+            config: Config::empty(),
+            queen: QueenCell::new(Q::default()),
+        }
+    }
+
+    /// Creates a new `BeeBuilder` with a queen created with
+    /// [`Q::default()`](std::default::Default) and options configured with global defaults.
+    pub fn preset_with_queen_mut_default() -> Self {
+        Self {
+            config: Config::default(),
+            queen: QueenCell::new(Q::default()),
+        }
+    }
+}
+
+impl BeeBuilder> {
+    /// Creates a new `BeeBuilder` with a [`CloneQueen`] that creates [`Worker`]s by cloning
+    /// `worker`, and with no options configured.
+    pub fn empty_with_worker(worker: W) -> Self {
+        Self {
+            config: Config::empty(),
+            queen: CloneQueen::new(worker),
+        }
+    }
+
+    /// Creates a new `BeeBuilder` with a [`CloneQueen`] that creates [`Worker`]s by cloning
+    /// `worker`, and with options configured with global preset values.
+    pub fn default_with_worker(worker: W) -> Self {
+        Self {
+            config: Config::default(),
+            queen: CloneQueen::new(worker),
+        }
+    }
+}
+
+impl BeeBuilder> {
+    /// Creates a new `BeeBuilder` with a [`DefaultQueen`] that creates [`Worker`]s using
+    /// [`W::default()`](std::default::Default), and with no options configured.
+    pub fn empty_with_worker_default() -> Self {
+        Self {
+            config: Config::empty(),
+            queen: DefaultQueen::default(),
+        }
+    }
+
+    /// Creates a new `BeeBuilder` with a [`DefaultQueen`] that creates [`Worker`]s using
+    /// [`W::default()`](std::default::Default), and with options configured with global
+    /// defaults.
+    pub fn preset_with_worker_default() -> Self {
+        Self {
+            config: Config::default(),
+            queen: DefaultQueen::default(),
+        }
+    }
+}
+
+impl BuilderConfig for BeeBuilder {
+    fn config(&mut self, _: Token) -> &mut Config {
+        &mut self.config
+    }
+}
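A usage sketch of the bee-typed flow (an editorial example, not part of the patch; it assumes `BeeBuilder`, the `Builder` trait, and the stock `ThunkWorker` are re-exported from `beekeeper::hive` and `beekeeper::bee` as in the examples elsewhere in this patch):

```rust
use beekeeper::bee::stock::{Thunk, ThunkWorker};
use beekeeper::bee::DefaultQueen;
use beekeeper::hive::{BeeBuilder, Builder};

fn main() {
    // pin the bee types first (the queen is created via `Default`), then pick
    // the queue implementation, and only then build the hive
    let hive = BeeBuilder::<DefaultQueen<ThunkWorker<()>>>::empty_with_worker_default()
        .num_threads(2)
        .with_channel_queues()
        .build();
    hive.apply_store(Thunk::of(|| println!("hello from a worker thread")));
    hive.join();
}
```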
diff --git a/src/hive/builder/channel.rs b/src/hive/builder/channel.rs
new file mode 100644
index 0000000..1525d5f
--- /dev/null
+++ b/src/hive/builder/channel.rs
@@ -0,0 +1,71 @@
+use super::{BuilderConfig, FullBuilder, Token};
+use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker};
+use crate::hive::{ChannelTaskQueues, Config};
+
+/// A builder for `Hive`s that use channel-based task queues.
+#[derive(Clone, Default)]
+pub struct ChannelBuilder(Config);
+
+impl ChannelBuilder {
+    /// Creates a new `ChannelBuilder` with no options configured.
+    pub fn empty() -> Self {
+        Self(Config::empty())
+    }
+
+    /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to
+    /// create [`Worker`]s.
+    pub fn with_queen(self, queen: I) -> FullBuilder>
+    where
+        Q: Queen,
+        I: Into,
+    {
+        FullBuilder::from(self.0, queen.into())
+    }
+
+    /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`Queen`] created with
+    /// [`Q::default()`](std::default::Default) to create [`Worker`]s.
+    pub fn with_queen_default(self) -> FullBuilder>
+    where
+        Q: Queen + Default,
+    {
+        FullBuilder::from(self.0, Q::default())
+    }
+
+    /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`QueenMut`] created with
+    /// [`Q::default()`](std::default::Default) to create [`Worker`]s.
+    pub fn with_queen_mut_default(self) -> FullBuilder, ChannelTaskQueues>
+    where
+        Q: QueenMut + Default,
+    {
+        FullBuilder::from(self.0, QueenCell::new(Q::default()))
+    }
+
+    /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created by
+    /// cloning `worker`.
+    pub fn with_worker(self, worker: W) -> FullBuilder, ChannelTaskQueues>
+    where
+        W: Worker + Send + Sync + Clone,
+    {
+        FullBuilder::from(self.0, CloneQueen::new(worker))
+    }
+
+    /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created using
+    /// [`W::default()`](std::default::Default).
+    pub fn with_worker_default(self) -> FullBuilder, ChannelTaskQueues>
+    where
+        W: Worker + Send + Sync + Default,
+    {
+        FullBuilder::from(self.0, DefaultQueen::default())
+    }
+}
+
+impl BuilderConfig for ChannelBuilder {
+    fn config(&mut self, _: Token) -> &mut Config {
+        &mut self.0
+    }
+}
+
+impl From for ChannelBuilder {
+    fn from(value: Config) -> Self {
+        Self(value)
+    }
+}
diff --git a/src/hive/builder/full.rs b/src/hive/builder/full.rs
new file mode 100644
index 0000000..e2a88e7
--- /dev/null
+++ b/src/hive/builder/full.rs
@@ -0,0 +1,48 @@
+use super::{BuilderConfig, Token};
+use crate::bee::Queen;
+use crate::hive::{Config, Hive, TaskQueues};
+use std::marker::PhantomData;
+
+/// A Builder for creating `Hive` instances for specific [`Queen`] and [`TaskQueues`] types.
+#[derive(Clone, Default)]
+pub struct FullBuilder> {
+    config: Config,
+    queen: Q,
+    _queues: PhantomData,
+}
+
+impl> FullBuilder {
+    /// Creates a new `FullBuilder` with the given queen and no options configured.
+    pub fn empty>(queen: Q) -> Self {
+        Self {
+            config: Config::empty(),
+            queen: queen.into(),
+            _queues: PhantomData,
+        }
+    }
+
+    /// Creates a new `FullBuilder` with the given queen and options configured with global
+    /// preset values.
+    pub fn preset>(queen: I) -> Self {
+        Self {
+            config: Config::default(),
+            queen: queen.into(),
+            _queues: PhantomData,
+        }
+    }
+
+    /// Creates a new `FullBuilder` from an existing `config` and a `queen`.
+    pub(super) fn from(config: Config, queen: Q) -> Self {
+        Self {
+            config,
+            queen,
+            _queues: PhantomData,
+        }
+    }
+
+    /// Consumes this builder and creates a new [`Hive`] with the configured parameters.
+    pub fn build(self) -> Hive {
+        Hive::new(self.config, self.queen)
+    }
+}
+
+impl> BuilderConfig for FullBuilder {
+    fn config(&mut self, _: Token) -> &mut Config {
+        &mut self.config
+    }
+}
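`FullBuilder` above records its `TaskQueues` choice purely at the type level via the zero-sized `_queues` field. A standalone, runnable sketch of that technique in plain Rust (illustrative names, not beekeeper code):

```rust
use std::marker::PhantomData;

trait TaskQueues {
    fn describe() -> &'static str;
}

struct Channel;
impl TaskQueues for Channel {
    fn describe() -> &'static str {
        "channel-based"
    }
}

// like `FullBuilder`, this builder never holds a queue value; the zero-sized
// `PhantomData<T>` records the choice in the type system
struct BuilderOf<T: TaskQueues> {
    _queues: PhantomData<T>,
}

impl<T: TaskQueues> BuilderOf<T> {
    fn new() -> Self {
        BuilderOf { _queues: PhantomData }
    }

    fn build(self) -> String {
        format!("hive with {} task queues", T::describe())
    }
}

fn main() {
    let hive = BuilderOf::<Channel>::new().build();
    assert_eq!(hive, "hive with channel-based task queues");
}
```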
diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs
new file mode 100644
index 0000000..7c809cd
--- /dev/null
+++ b/src/hive/builder/mod.rs
@@ -0,0 +1,50 @@
+//! There are a few different builder types. All builders implement the `BuilderConfig` trait,
+//! which provides methods to set configuration parameters.
+//!
+//! * Open: has no type parameters; can only set config parameters. Has methods to create
+//!   typed builders.
+//! * Bee-typed: has type parameters for the `Worker` and `Queen` types.
+//! * Queue-typed: builder instances that are specific to the `TaskQueues` type.
+//! * Fully-typed: builder that has type parameters for the `Worker`, `Queen`, and `TaskQueues`
+//!   types. This is the only builder with a `build` method to create a `Hive`.
+//!
+//!     Open ---- Queue
+//!      |       /
+//!     Bee     /
+//!      |     /
+//!     Full
+mod bee;
+mod channel;
+mod full;
+mod open;
+
+pub use bee::BeeBuilder;
+pub use channel::ChannelBuilder;
+pub use full::FullBuilder;
+pub use open::OpenBuilder;
+
+use crate::hive::inner::{BuilderConfig, Token};
+
+// #[cfg(all(test, feature = "affinity"))]
+// mod affinity_tests {
+//     use super::{OpenBuilder, Token};
+//     use crate::hive::cores::Cores;

+//     #[test]
+//     fn test_with_affinity() {
+//         let mut builder = OpenBuilder::empty();
+//         builder = builder.with_default_core_affinity();
+//         assert_eq!(builder.config(Token).affinity.get(), Some(Cores::all()));
+//     }
+// }
+
+// #[cfg(all(test, feature = "batching"))]
+// mod batching_tests {
+//     use super::OpenBuilder;
+// }
+
+// #[cfg(all(test, feature = "retry"))]
+// mod retry_tests {
+//     use super::OpenBuilder;
+//     use std::time::Duration;
+// }
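The `BuilderConfig`/`Builder` split described above is a sealed-trait pattern: the config hook takes a `Token` argument whose type is private to the crate, so external code can call the public `Builder` methods but cannot invoke or meaningfully implement the hook. A standalone, runnable sketch of the technique in plain Rust (here `Token` is public only so the example compiles as a single file):

```rust
mod hive {
    /// zero-sized token; kept private to the crate in the real implementation
    pub struct Token;

    pub struct Config {
        pub num_threads: Option<usize>,
    }

    pub trait BuilderConfig {
        fn config(&mut self, token: Token) -> &mut Config;
    }

    pub trait Builder: BuilderConfig + Sized {
        fn num_threads(mut self, num: usize) -> Self {
            self.config(Token).num_threads = Some(num);
            self
        }
    }

    // blanket impl: every builder type gets the `Builder` methods for free
    impl<B: BuilderConfig> Builder for B {}

    pub struct OpenBuilder(pub Config);

    impl BuilderConfig for OpenBuilder {
        fn config(&mut self, _: Token) -> &mut Config {
            &mut self.0
        }
    }
}

use hive::Builder;

fn main() {
    let builder = hive::OpenBuilder(hive::Config { num_threads: None }).num_threads(8);
    assert_eq!(builder.0.num_threads, Some(8));
}
```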
diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs
new file mode 100644
index 0000000..89d9ab8
--- /dev/null
+++ b/src/hive/builder/open.rs
@@ -0,0 +1,277 @@
+use super::{BeeBuilder, BuilderConfig, ChannelBuilder, Token};
+use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker};
+use crate::hive::Config;
+
+/// A builder for a [`Hive`](crate::hive::Hive).
+///
+/// Calling [`OpenBuilder::empty()`] creates an unconfigured builder, while calling
+/// [`OpenBuilder::default()`] creates a builder with fields preset to the global default values.
+/// Global defaults can be changed using the
+/// [`beekeeper::hive::set_*_default`](crate::hive#functions) functions.
+///
+/// The configuration options available:
+/// * [`Builder::num_threads`]: number of worker threads that will be spawned by the built `Hive`.
+/// * [`Builder::with_default_num_threads`] will set `num_threads` to the global default value.
+/// * [`Builder::with_thread_per_core`] will set `num_threads` to the number of available CPU
+///   cores.
+/// * [`Builder::thread_name`]: thread name for each of the threads spawned by the built `Hive`. By
+///   default, threads are unnamed.
+/// * [`Builder::thread_stack_size`]: stack size (in bytes) for each of the threads spawned by the
+///   built `Hive`. See the
+///   [`std::thread`](https://doc.rust-lang.org/stable/std/thread/index.html#stack-size)
+///   documentation for details on the default stack size.
+///
+/// The following configuration options are available when the `retry` feature is enabled:
+/// * [`Builder::max_retries`]: maximum number of times a `Worker` will retry an
+///   [`ApplyError::Retryable`](crate::bee::ApplyError#Retryable) before giving up.
+/// * [`Builder::retry_factor`]: [`Duration`](std::time::Duration) factor for exponential backoff
+///   when retrying an `ApplyError::Retryable` error.
+/// * [`Builder::with_default_retries`] sets the retry options to the global defaults, while
+///   [`Builder::with_no_retries`] disables retrying.
+///
+/// The following configuration options are available when the `affinity` feature is enabled:
+/// * [`Builder::core_affinity`]: List of CPU core indices to which the threads should be pinned.
+/// * [`Builder::with_default_core_affinity`] will set the list to all CPU core indices, though
+///   only the first `num_threads` indices will be used.
+///
+/// To create the [`Hive`], first convert this builder into a bee-typed and then a fully-typed
+/// builder, and call `build`:
+/// * [`OpenBuilder::with_queen`] requires a [`Queen`] instance.
+/// * [`OpenBuilder::with_queen_default`] requires a [`Queen`] type that implements [`Default`].
+/// * [`OpenBuilder::with_worker`] requires a [`Worker`] instance that implements [`Clone`].
+/// * [`OpenBuilder::with_worker_default`] requires a [`Worker`] type that implements [`Default`].
+///
+/// # Examples
+///
+/// Build a [`Hive`] that uses a maximum of eight threads simultaneously, where each thread has
+/// an 8 MB stack size:
+///
+/// ```
+/// use beekeeper::hive::{Builder, OpenBuilder};
+///
+/// type MyWorker = beekeeper::bee::stock::ThunkWorker<()>;
+///
+/// let hive = OpenBuilder::empty()
+///     .num_threads(8)
+///     .thread_stack_size(8_000_000)
+///     .with_worker_default::<MyWorker>()
+///     .with_channel_queues()
+///     .build();
+/// ```
+#[derive(Clone, Default)]
+pub struct OpenBuilder(Config);
+
+impl OpenBuilder {
+    /// Returns a new `OpenBuilder` with no options configured.
+    pub fn empty() -> Self {
+        Self(Config::empty())
+    }
+
+    /// Consumes this `Builder` and returns a new [`BeeBuilder`] using the given [`Queen`] to
+    /// create [`Worker`]s.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use beekeeper::hive::{Builder, OpenBuilder};
+    /// # use beekeeper::bee::{Context, Queen, Worker, WorkerResult};
+    ///
+    /// #[derive(Debug)]
+    /// struct CounterWorker {
+    ///     index: usize,
+    ///     input_count: usize,
+    ///     input_sum: usize,
+    /// }
+    ///
+    /// impl CounterWorker {
+    ///     fn new(index: usize) -> Self {
+    ///         Self {
+    ///             index,
+    ///             input_count: 0,
+    ///             input_sum: 0,
+    ///         }
+    ///     }
+    /// }
+    ///
+    /// impl Worker for CounterWorker {
+    ///     type Input = usize;
+    ///     type Output = String;
+    ///     type Error = ();
+    ///
+    ///     fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult {
+    ///         self.input_count += 1;
+    ///         self.input_sum += input;
+    ///         let s = format!(
+    ///             "CounterWorker {}: Input {}, Count {}, Sum {}",
+    ///             self.index, input, self.input_count, self.input_sum
+    ///         );
+    ///         Ok(s)
+    ///     }
+    /// }
+    ///
+    /// #[derive(Debug, Default)]
+    /// struct CounterQueen {
+    ///     num_workers: usize
+    /// }
+    ///
+    /// impl Queen for CounterQueen {
+    ///     type Kind = CounterWorker;
+    ///
+    ///     fn create(&mut self) -> Self::Kind {
+    ///         self.num_workers += 1;
+    ///         CounterWorker::new(self.num_workers)
+    ///     }
+    /// }
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::empty()
+    ///     .num_threads(8)
+    ///     .thread_stack_size(4_000_000)
+    ///     .with_queen(CounterQueen::default())
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for i in 0..100 {
+    ///     hive.apply_store(i);
+    /// }
+    /// let husk = hive.try_into_husk().unwrap();
+    /// assert_eq!(husk.queen().num_workers, 8);
+    /// # }
+    /// ```
+    pub fn with_queen>(self, queen: I) -> BeeBuilder {
+        BeeBuilder::from(self.0, queen.into())
+    }
+
+    /// Consumes this `Builder` and returns a new [`BeeBuilder`] using a [`Queen`] created with
+    /// [`Q::default()`](std::default::Default) to create [`Worker`]s.
+    pub fn with_queen_default(self) -> BeeBuilder {
+        BeeBuilder::from(self.0, Q::default())
+    }
+
+    /// Consumes this `Builder` and returns a new [`BeeBuilder`] using a [`QueenMut`] created with
+    /// [`Q::default()`](std::default::Default) to create [`Worker`]s.
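+    ///
+    /// `QueenCell` adapts a mutable queen to the shared (`&self`) context the hive needs. A
+    /// standalone sketch of that adapter idea (plain Rust, not the actual `QueenCell`, which
+    /// may use a different synchronization primitive):
+    ///
+    /// ```
+    /// use std::sync::Mutex;
+    ///
+    /// trait QueenMut { type Kind; fn create(&mut self) -> Self::Kind; }
+    ///
+    /// // expose `create` through `&self` by serializing access with a lock
+    /// struct Cell<Q>(Mutex<Q>);
+    /// impl<Q: QueenMut> Cell<Q> {
+    ///     fn create(&self) -> Q::Kind { self.0.lock().unwrap().create() }
+    /// }
+    ///
+    /// struct Counter(usize);
+    /// impl QueenMut for Counter {
+    ///     type Kind = usize;
+    ///     fn create(&mut self) -> usize { self.0 += 1; self.0 }
+    /// }
+    ///
+    /// let cell = Cell(Mutex::new(Counter(0)));
+    /// assert_eq!(cell.create(), 1);
+    /// assert_eq!(cell.create(), 2);
+    /// ```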
+    pub fn with_queen_mut_default(self) -> BeeBuilder> {
+        BeeBuilder::from(self.0, QueenCell::new(Q::default()))
+    }
+
+    /// Consumes this `Builder` and returns a new [`BeeBuilder`] with [`Worker`]s created by
+    /// cloning `worker`.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use beekeeper::hive::{Builder, OpenBuilder, OutcomeIteratorExt};
+    /// # use beekeeper::bee::{Context, Worker, WorkerResult};
+    ///
+    /// #[derive(Debug, Clone)]
+    /// struct MathWorker(isize);
+    ///
+    /// impl MathWorker {
+    ///     fn new(left_operand: isize) -> Self {
+    ///         assert!(left_operand != 0);
+    ///         Self(left_operand)
+    ///     }
+    /// }
+    ///
+    /// impl Worker for MathWorker {
+    ///     type Input = (isize, u8);
+    ///     type Output = isize;
+    ///     type Error = ();
+    ///
+    ///     fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult {
+    ///         let (operand, operator) = input;
+    ///         let value = match operator % 4 {
+    ///             0 => operand + self.0,
+    ///             1 => operand - self.0,
+    ///             2 => operand * self.0,
+    ///             3 => operand / self.0,
+    ///             _ => unreachable!(),
+    ///         };
+    ///         Ok(value)
+    ///     }
+    /// }
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::empty()
+    ///     .num_threads(8)
+    ///     .thread_stack_size(4_000_000)
+    ///     .with_worker(MathWorker(5isize))
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// let sum: isize = hive
+    ///     .map((0..100).zip((0..4).cycle()))
+    ///     .into_outputs()
+    ///     .sum();
+    /// assert_eq!(sum, 8920);
+    /// # }
+    /// ```
+    pub fn with_worker(self, worker: W) -> BeeBuilder>
+    where
+        W: Worker + Send + Sync + Clone,
+    {
+        BeeBuilder::from(self.0, CloneQueen::new(worker))
+    }
+
+    /// Consumes this `Builder` and returns a new [`BeeBuilder`] with [`Worker`]s created using
+    /// [`W::default()`](std::default::Default).
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// # use beekeeper::hive::{Builder, OpenBuilder, OutcomeIteratorExt};
+    /// # use beekeeper::bee::{Context, Worker, WorkerResult};
+    /// # use std::num::NonZeroIsize;
+    ///
+    /// #[derive(Debug, Default)]
+    /// struct MathWorker(isize); // value is always `0`
+    ///
+    /// impl Worker for MathWorker {
+    ///     type Input = (NonZeroIsize, u8);
+    ///     type Output = isize;
+    ///     type Error = ();
+    ///
+    ///     fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult {
+    ///         let (operand, operator) = input;
+    ///         let result = match operator % 4 {
+    ///             0 => self.0 + operand.get(),
+    ///             1 => self.0 - operand.get(),
+    ///             2 => self.0 * operand.get(),
+    ///             3 => self.0 / operand.get(),
+    ///             _ => unreachable!(),
+    ///         };
+    ///         Ok(result)
+    ///     }
+    /// }
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::empty()
+    ///     .num_threads(8)
+    ///     .thread_stack_size(4_000_000)
+    ///     .with_worker_default::<MathWorker>()
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// let sum: isize = hive
+    ///     .map((1..=100).map(|i| NonZeroIsize::new(i).unwrap()).zip((0..4).cycle()))
+    ///     .into_outputs()
+    ///     .sum();
+    /// assert_eq!(sum, -25);
+    /// # }
+    /// ```
+    pub fn with_worker_default(self) -> BeeBuilder>
+    where
+        W: Worker + Send + Sync + Default,
+    {
+        BeeBuilder::from(self.0, DefaultQueen::default())
+    }
+
+    /// Consumes this `Builder` and returns a new [`ChannelBuilder`] using the current
+    /// configuration.
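+    ///
+    /// # Examples
+    ///
+    /// A sketch of the intended flow (using the stock `ThunkWorker`, as in the tests
+    /// elsewhere in this patch); the resulting hive uses channel-based task queues:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::ThunkWorker;
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    ///
+    /// let hive = OpenBuilder::empty()
+    ///     .num_threads(4)
+    ///     .with_channel_queues()
+    ///     .with_worker_default::<ThunkWorker<()>>()
+    ///     .build();
+    /// ```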
+ pub fn with_channel_queues(self) -> ChannelBuilder { + ChannelBuilder::from(self.0) + } +} + +impl BuilderConfig for OpenBuilder { + fn config(&mut self, _: Token) -> &mut Config { + &mut self.0 + } +} + +impl From for OpenBuilder { + fn from(value: Config) -> Self { + Self(value) + } +} diff --git a/src/hive/core.rs b/src/hive/hive.rs similarity index 81% rename from src/hive/core.rs rename to src/hive/hive.rs index 2ae15c5..5787250 100644 --- a/src/hive/core.rs +++ b/src/hive/hive.rs @@ -1,8 +1,7 @@ -use super::prelude::*; use super::{ - Config, DerefOutcomes, GlobalQueue, LocalQueues, OutcomeSender, QueuePair, Shared, SpawnError, + ChannelBuilder, ChannelTaskQueues, Config, DerefOutcomes, Husk, Outcome, OutcomeBatch, + OutcomeIteratorExt, OutcomeSender, Shared, SpawnError, TaskQueues, }; -use crate::atomic::Atomic; use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; use crossbeam_utils::Backoff; use std::collections::HashMap; @@ -15,14 +14,12 @@ use std::thread::{self, JoinHandle}; #[error("The hive has been poisoned")] pub struct Poisoned; -impl< - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, - > Hive -{ +/// A pool of worker threads that each execute the same function. +/// +/// See the [module documentation](crate::hive) for details. +pub struct Hive>(Option>>); + +impl> Hive { /// Creates a new `Hive`. This should only be called from `Builder`. /// /// The `Hive` will attempt to spawn the configured number of worker threads @@ -35,18 +32,11 @@ impl< } } -impl< - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, - > Hive -{ +impl, T: TaskQueues> Hive { /// Spawns a new worker thread with the specified index and with access to the `shared` data. fn try_spawn( thread_index: usize, - shared: &Arc>, + shared: &Arc>, ) -> Result, SpawnError> { let thread_builder = shared.thread_builder(); let shared = Arc::clone(shared); @@ -73,7 +63,7 @@ impl< } #[inline] - fn shared(&self) -> &Arc> { + fn shared(&self) -> &Arc> { self.0.as_ref().unwrap() } @@ -107,7 +97,7 @@ impl< /// result is available. Creates a channel to send the input and receive the outcome. Returns /// an [`Outcome`] with the task output or an error. pub fn apply(&self, input: W::Input) -> Outcome { - let (tx, rx) = outcome_channel(); + let (tx, rx) = super::outcome_channel(); let task_id = self.shared().send_one_global(input, Some(&tx)); rx.recv().unwrap_or_else(|_| Outcome::Missing { task_id }) } @@ -129,12 +119,12 @@ impl< /// /// This method is more efficient than [`map`](Self::map) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm(&self, batch: T) -> impl Iterator> + pub fn swarm(&self, batch: I) -> impl Iterator> where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { - let (tx, rx) = outcome_channel(); + let (tx, rx) = super::outcome_channel(); let task_ids = self.shared().send_batch_global(batch, Some(&tx)); rx.select_ordered(task_ids) } @@ -146,12 +136,12 @@ impl< /// instead receive the `Outcome`s in the order they were submitted. This method is more /// efficient than [`map_unordered`](Self::map_unordered) when the input is an /// [`ExactSizeIterator`]. 
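+    ///
+    /// A usage sketch (marked `ignore` since it assumes a hive whose worker doubles its
+    /// `usize` input, as in the tests at the bottom of this file):
+    ///
+    /// ```ignore
+    /// use beekeeper::hive::OutcomeIteratorExt;
+    ///
+    /// // outcomes are yielded as tasks finish, not in submission order
+    /// let sum: usize = hive.swarm_unordered(0..10).into_outputs().sum();
+    /// assert_eq!(sum, 90);
+    /// ```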
- pub fn swarm_unordered(&self, batch: T) -> impl Iterator> + pub fn swarm_unordered(&self, batch: I) -> impl Iterator> where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { - let (tx, rx) = outcome_channel(); + let (tx, rx) = super::outcome_channel(); let task_ids = self.shared().send_batch_global(batch, Some(&tx)); rx.select_unordered(task_ids) } @@ -161,10 +151,10 @@ impl< /// /// This method is more efficient than [`map_send`](Self::map_send) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm_send(&self, batch: T, outcome_tx: &OutcomeSender) -> Vec + pub fn swarm_send(&self, batch: I, outcome_tx: &OutcomeSender) -> Vec where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { self.shared().send_batch_global(batch, Some(outcome_tx)) } @@ -173,10 +163,10 @@ impl< /// The [`Outcome`]s of the task are retained and available for later retrieval. /// /// This method is more efficient than `map_store` when the input is an [`ExactSizeIterator`]. - pub fn swarm_store(&self, batch: T) -> Vec + pub fn swarm_store(&self, batch: I) -> Vec where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { self.shared().send_batch_global(batch, None) } @@ -189,7 +179,7 @@ impl< &self, inputs: impl IntoIterator, ) -> impl Iterator> { - let (tx, rx) = outcome_channel(); + let (tx, rx) = super::outcome_channel(); let task_ids: Vec<_> = inputs .into_iter() .map(|task| self.apply_send(task, &tx)) @@ -206,7 +196,7 @@ impl< &self, inputs: impl IntoIterator, ) -> impl Iterator> { - let (tx, rx) = outcome_channel(); + let (tx, rx) = super::outcome_channel(); // `map` is required (rather than `inspect`) because we need owned items let task_ids: Vec<_> = inputs .into_iter() @@ -246,16 +236,16 @@ impl< /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing. /// Returns an [`OutcomeBatch`] of the outputs and the final state value. - pub fn scan( + pub fn scan( &self, - items: impl IntoIterator, + items: impl IntoIterator, init: St, f: F, ) -> (OutcomeBatch, St) where - F: FnMut(&mut St, T) -> W::Input, + F: FnMut(&mut St, I) -> W::Input, { - let (tx, rx) = outcome_channel(); + let (tx, rx) = super::outcome_channel(); let (task_ids, fold_value) = self.scan_send(items, &tx, init, f); let outcomes = rx.select_unordered(task_ids).into(); (outcomes, fold_value) @@ -265,16 +255,16 @@ impl< /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing, /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final /// state value. - pub fn try_scan( + pub fn try_scan( &self, - items: impl IntoIterator, + items: impl IntoIterator, init: St, mut f: F, ) -> (OutcomeBatch, Vec, St) where - F: FnMut(&mut St, T) -> Result, + F: FnMut(&mut St, I) -> Result, { - let (tx, rx) = outcome_channel(); + let (tx, rx) = super::outcome_channel(); let (task_ids, errors, fold_value) = items.into_iter().fold( (Vec::new(), Vec::new(), init), |(mut task_ids, mut errors, mut acc), inp| { @@ -293,15 +283,15 @@ impl< /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the /// task IDs and the final state value. 
- pub fn scan_send( + pub fn scan_send( &self, - items: impl IntoIterator, + items: impl IntoIterator, tx: &OutcomeSender, init: St, mut f: F, ) -> (Vec, St) where - F: FnMut(&mut St, T) -> W::Input, + F: FnMut(&mut St, I) -> W::Input, { items .into_iter() @@ -317,15 +307,15 @@ impl< /// or an error. The outputs are sent to `tx` in the order they become available. This /// function returns the final state value and a [`Vec`] of results, where each result is /// either a task ID or an error. - pub fn try_scan_send( + pub fn try_scan_send( &self, - items: impl IntoIterator, + items: impl IntoIterator, tx: &OutcomeSender, init: St, mut f: F, ) -> (Vec>, St) where - F: FnMut(&mut St, T) -> Result, + F: FnMut(&mut St, I) -> Result, { items .into_iter() @@ -339,14 +329,14 @@ impl< /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of /// the tasks are retained and available for later retrieval. - pub fn scan_store( + pub fn scan_store( &self, - items: impl IntoIterator, + items: impl IntoIterator, init: St, mut f: F, ) -> (Vec, St) where - F: FnMut(&mut St, T) -> W::Input, + F: FnMut(&mut St, I) -> W::Input, { items .into_iter() @@ -362,14 +352,14 @@ impl< /// or an error. This function returns the final value of the state value and a [`Vec`] of /// results, where each result is either a task ID or an error. The [`Outcome`]s of the /// tasks are retained and available for later retrieval. - pub fn try_scan_store( + pub fn try_scan_store( &self, - items: impl IntoIterator, + items: impl IntoIterator, init: St, mut f: F, ) -> (Vec>, St) where - F: FnMut(&mut St, T) -> Result, + F: FnMut(&mut St, I) -> Result, { items .into_iter() @@ -381,30 +371,26 @@ impl< /// Blocks the calling thread until all tasks finish. pub fn join(&self) { - (self.shared()).wait_on_done(); + self.shared().wait_on_done(); } - /// Returns the [`MutexGuard`](parking_lot::MutexGuard) for the [`Queen`]. - /// - /// Note that the `Queen` will remain locked until the returned guard is dropped, and that - /// locking the `Queen` prevents new worker threads from being started. + /// Returns a read-only reference to the [`Queen`]. pub fn queen(&self) -> &Q { - &self.shared().queen + &self.shared().queen() } /// Returns the number of worker threads that have been requested, i.e., the maximum number of /// tasks that could be processed concurrently. This may be greater than /// [`active_workers`](Self::active_workers) if any of the worker threads failed to start. pub fn max_workers(&self) -> usize { - (self.shared()).config.num_threads.get_or_default() + self.shared().num_threads() } /// Returns the number of worker threads that have been successfully started. This may be /// fewer than [`max_workers`](Self::max_workers) if any of the worker threads failed to start. pub fn alive_workers(&self) -> usize { - (self.shared()) - .spawn_results - .lock() + self.shared() + .spawn_results() .iter() .filter(|result| result.is_ok()) .count() @@ -412,9 +398,8 @@ impl< /// Returns `true` if there are any "dead" worker threads that failed to spawn. pub fn has_dead_workers(&self) -> bool { - (self.shared()) - .spawn_results - .lock() + self.shared() + .spawn_results() .iter() .any(|result| result.is_err()) } @@ -428,12 +413,12 @@ impl< /// Returns the number of tasks currently (queued for processing, being processed). 
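+    ///
+    /// Both counts are kept in a single packed counter (see `DualCounter` in the `inner`
+    /// module), so the pair can be read together. A sketch:
+    ///
+    /// ```ignore
+    /// let (num_queued, num_active) = hive.num_tasks();
+    /// println!("{num_queued} queued, {num_active} active");
+    /// ```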
pub fn num_tasks(&self) -> (u64, u64) { - (self.shared()).num_tasks() + self.shared().num_tasks() } /// Returns the number of times one of this `Hive`'s worker threads has panicked. pub fn num_panics(&self) -> usize { - (self.shared()).num_panics.get() + self.shared().num_panics() } /// Returns `true` if this `Hive` has been poisoned - i.e., its internal state has been @@ -443,12 +428,12 @@ impl< /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., /// [`try_into_husk`](Self::try_into_husk)). pub fn is_poisoned(&self) -> bool { - (self.shared()).is_poisoned() + self.shared().is_poisoned() } /// Returns `true` if the suspended flag is set. pub fn is_suspended(&self) -> bool { - (self.shared()).is_suspended() + self.shared().is_suspended() } /// Sets the suspended flag, which notifies worker threads that they a) MAY terminate their @@ -488,18 +473,18 @@ impl< /// # } /// ``` pub fn suspend(&self) { - (self.shared()).set_suspended(true); + self.shared().set_suspended(true); } /// Unsets the suspended flag, allowing worker threads to continue processing queued tasks. pub fn resume(&self) { - (self.shared()).set_suspended(false); + self.shared().set_suspended(false); } /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over /// the input values. fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator { - (self.shared()) + self.shared() .take_unprocessed() .into_iter() .map(|outcome| match outcome { @@ -512,7 +497,7 @@ impl< /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that /// were resumed. pub fn resume_send(&self, outcome_tx: &OutcomeSender) -> Vec { - (self.shared()) + self.shared() .set_suspended(false) .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)) .unwrap_or_default() @@ -522,7 +507,7 @@ impl< /// processing, with their results to be stored in the queue. Returns a [`Vec`] of task IDs /// that were resumed. pub fn resume_store(&self) -> Vec { - (self.shared()) + self.shared() .set_suspended(false) .then(|| self.swarm_store(self.take_unprocessed_inputs())) .unwrap_or_default() @@ -530,7 +515,7 @@ impl< /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s. pub fn take_stored(&self) -> HashMap> { - (self.shared()).take_outcomes() + self.shared().take_outcomes() } /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this @@ -541,14 +526,14 @@ impl< /// returns `None` since it cannot take exclusive ownership of the internal shared data. /// /// This method first joins on the `Hive` to wait for all tasks to finish. 
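+    ///
+    /// A usage sketch (`ignore`d here; it assumes all other clones of the `Hive` have been
+    /// dropped):
+    ///
+    /// ```ignore
+    /// hive.join();
+    /// let husk = hive.try_into_husk().unwrap();
+    /// let (queen, outcomes) = husk.into_parts();
+    /// ```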
- pub fn try_into_husk(mut self) -> Option> { - if (self.shared()).num_referrers() > 1 { + pub fn try_into_husk(mut self) -> Option> { + if self.shared().num_referrers() > 1 { return None; } // take the inner value and replace it with `None` let mut shared = self.0.take().unwrap(); // close the global queue to prevent new tasks from being submitted - shared.global_queue.close(); + shared.close(); // wait for all tasks to finish shared.wait_on_done(); // wait for worker threads to drop, then take ownership of the shared data and convert it @@ -571,29 +556,17 @@ impl< } } -use crate::hive::queue::{ChannelGlobalQueue, ChannelQueues, DefaultLocalQueues}; - -impl Default - for Hive< - W, - DefaultQueen, - ChannelGlobalQueue, - DefaultLocalQueues>, - ChannelQueues, - > -{ +impl Default for Hive, ChannelTaskQueues> { fn default() -> Self { - Builder::default().build_with_default::>() + ChannelBuilder::default().with_worker_default().build() } } -impl Clone for Hive +impl Clone for Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, + T: TaskQueues, { /// Creates a shallow copy of this `Hive` containing references to its same internal state, /// i.e., all clones of a `Hive` submit tasks to the same shared worker thread pool. @@ -604,13 +577,11 @@ where } } -impl fmt::Debug for Hive +impl fmt::Debug for Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, + T: TaskQueues, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(shared) = self.0.as_ref() { @@ -621,59 +592,49 @@ where } } -impl PartialEq for Hive +impl PartialEq for Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - - P: QueuePair, + T: TaskQueues, { - fn eq(&self, other: &Hive) -> bool { + fn eq(&self, other: &Hive) -> bool { let self_shared = self.shared(); let other_shared = &other.shared(); Arc::ptr_eq(self_shared, other_shared) } } -impl Eq for Hive +impl Eq for Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - - P: QueuePair, + T: TaskQueues, { } -impl DerefOutcomes for Hive +impl DerefOutcomes for Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, + T: TaskQueues, { #[inline] fn outcomes_deref(&self) -> impl Deref>> { - (self.shared()).outcomes() + self.shared().outcomes() } #[inline] fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { - (self.shared()).outcomes() + self.shared().outcomes() } } -impl Drop for Hive +impl Drop for Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, + T: TaskQueues, { fn drop(&mut self) { // if this Hive has already been turned into a Husk, it's inner value will be `None` @@ -692,24 +653,24 @@ where /// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker /// thread if it panics. 
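+///
+/// A standalone sketch of the underlying drop-guard technique (plain Rust, not the actual
+/// `Sentinel`): because `Drop` runs during unwinding, the guard can detect a panic and
+/// react, unless it was defused first:
+///
+/// ```
+/// struct Guard { active: bool }
+///
+/// impl Guard {
+///     fn cancel(mut self) { self.active = false; }
+/// }
+///
+/// impl Drop for Guard {
+///     fn drop(&mut self) {
+///         if self.active && std::thread::panicking() {
+///             eprintln!("worker panicked; a real sentinel would re-spawn the thread here");
+///         }
+///     }
+/// }
+///
+/// let guard = Guard { active: true };
+/// // ... run the worker loop, which might panic ...
+/// guard.cancel(); // reached only on normal exit
+/// ```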
-struct Sentinel +struct Sentinel where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { thread_index: usize, - shared: Arc>, + shared: Arc>, active: bool, } -impl Sentinel +impl Sentinel where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { - fn new(thread_index: usize, shared: Arc>) -> Self { + fn new(thread_index: usize, shared: Arc>) -> Self { Self { thread_index, shared, @@ -723,11 +684,11 @@ where } } -impl Drop for Sentinel +impl Drop for Sentinel where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { fn drop(&mut self) { if self.active { @@ -750,11 +711,11 @@ where #[cfg(not(feature = "affinity"))] mod no_affinity { use crate::bee::{Queen, Worker}; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Shared}; + use crate::hive::{Hive, Shared, TaskQueues}; - impl, G: GlobalQueue, L: LocalQueues> Hive { + impl, T: TaskQueues> Hive { #[inline] - pub(super) fn init_thread(_: usize, _: &Shared) {} + pub(super) fn init_thread(_: usize, _: &Shared) {} } } @@ -762,19 +723,17 @@ mod no_affinity { mod affinity { use crate::bee::{Queen, Worker}; use crate::hive::cores::Cores; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Poisoned, QueuePair, Shared}; + use crate::hive::{Hive, Poisoned, Shared, TaskQueues}; - impl Hive + impl Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, + T: TaskQueues, { /// Tries to pin the worker thread to a specific CPU core. #[inline] - pub(super) fn init_thread(thread_index: usize, shared: &Shared) { + pub(super) fn init_thread(thread_index: usize, shared: &Shared) { if let Some(core) = shared.get_core_affinity(thread_index) { core.try_pin_current(); } @@ -792,7 +751,7 @@ mod affinity { num_threads: usize, affinity: C, ) -> Result { - (self.shared()).add_core_affinity(affinity.into()); + self.shared().add_core_affinity(affinity.into()); self.grow(num_threads) } @@ -802,7 +761,7 @@ mod affinity { /// Returns the number of new threads spun up (if any) or a `Poisoned` error if the hive /// has been poisoned. pub fn use_all_cores_with_affinity(&self) -> Result { - (self.shared()).add_core_affinity(Cores::all()); + self.shared().add_core_affinity(Cores::all()); self.use_all_cores() } } @@ -811,45 +770,43 @@ mod affinity { #[cfg(feature = "batching")] mod batching { use crate::bee::{Queen, Worker}; - use crate::hive::{GlobalQueue, Hive, LocalQueues, QueuePair}; + use crate::hive::{Hive, TaskQueues}; - impl Hive + impl Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, + T: TaskQueues, { /// Returns the batch size for worker threads. pub fn worker_batch_size(&self) -> usize { - (self.shared()).batch_size() + self.shared().batch_size() } /// Sets the batch size for worker threads. This will block the current thread until all /// worker thread queues can be resized. 
pub fn set_worker_batch_size(&self, batch_size: usize) { - (self.shared()).set_batch_size(batch_size); + self.shared().set_batch_size(batch_size); } } } -struct HiveTaskContext<'a, W, Q, P> +struct HiveTaskContext<'a, W, Q, T> where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { thread_index: usize, - shared: &'a Arc>, + shared: &'a Arc>, outcome_tx: Option<&'a OutcomeSender>, } -impl TaskContext for HiveTaskContext<'_, W, Q, P> +impl TaskContext for HiveTaskContext<'_, W, Q, T> where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { fn should_cancel_tasks(&self) -> bool { self.shared.is_suspended() @@ -861,11 +818,11 @@ where } } -impl fmt::Debug for HiveTaskContext<'_, W, Q, P> +impl fmt::Debug for HiveTaskContext<'_, W, Q, T> where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("HiveTaskContext").finish() @@ -876,21 +833,20 @@ where mod no_retry { use super::HiveTaskContext; use crate::bee::{Context, Queen, Worker}; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, Shared, Task}; + use crate::hive::{Hive, Outcome, Shared, Task, TaskQueues}; use std::sync::Arc; - impl Hive + impl Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, + T: TaskQueues, { pub(super) fn execute( task: Task, thread_index: usize, worker: &mut W, - shared: &Arc>, + shared: &Arc>, ) { let (task_id, input, outcome_tx) = task.into_parts(); let task_ctx = HiveTaskContext { @@ -911,22 +867,20 @@ mod no_retry { mod retry { use super::HiveTaskContext; use crate::bee::{ApplyError, Context, Queen, Worker}; - use crate::hive::{GlobalQueue, Hive, LocalQueues, Outcome, QueuePair, Shared, Task}; + use crate::hive::{Hive, Outcome, Shared, Task, TaskQueues}; use std::sync::Arc; - impl Hive + impl Hive where W: Worker, Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, + T: TaskQueues, { pub(super) fn execute( task: Task, thread_index: usize, worker: &mut W, - shared: &Arc>, + shared: &Arc>, ) { let (task_id, input, attempt, outcome_tx) = task.into_parts(); let task_ctx = HiveTaskContext { @@ -958,17 +912,17 @@ mod retry { mod tests { use super::Poisoned; use crate::bee::stock::{Caller, Thunk, ThunkWorker}; - use crate::hive::queue::ChannelQueues; - use crate::hive::{outcome_channel, Builder, Hive, Outcome, OutcomeIteratorExt}; + use crate::hive::{outcome_channel, Builder, ChannelBuilder, Outcome, OutcomeIteratorExt}; use std::collections::HashMap; use std::thread; use std::time::Duration; #[test] fn test_suspend() { - let hive = Builder::new() + let hive = ChannelBuilder::empty() .num_threads(4) - .build_with_default::, ChannelQueues<_>>(); + .with_worker_default::>() + .build(); let outcome_iter = hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); // Allow first set of tasks to be started. 
@@ -989,9 +943,10 @@ mod tests { #[test] fn test_spawn_after_poison() { - let hive = Builder::new() + let hive = ChannelBuilder::empty() .num_threads(4) - .build_with_default::, ChannelQueues<_>>(); + .with_worker_default::>() + .build(); assert_eq!(hive.max_workers(), 4); assert_eq!(hive.alive_workers(), 4); // poison hive using private method @@ -1005,9 +960,10 @@ mod tests { #[test] fn test_apply_after_poison() { - let hive = Builder::new() + let hive = ChannelBuilder::empty() .num_threads(4) - .build_with::<_, ChannelQueues<_>>(Caller::of(|i: usize| i * 2)); + .with_worker(Caller::of(|i: usize| i * 2)) + .build(); // poison hive using private method hive.0.as_ref().unwrap().poison(); // submit a task, check that it comes back unprocessed @@ -1026,9 +982,10 @@ mod tests { #[test] fn test_swarm_after_poison() { - let hive = Builder::new() + let hive = ChannelBuilder::empty() .num_threads(4) - .build_with::<_, ChannelQueues<_>>(Caller::of(|i: usize| i * 2)); + .with_worker(Caller::of(|i: usize| i * 2)) + .build(); // poison hive using private method hive.0.as_ref().unwrap().poison(); // submit a task, check that it comes back unprocessed diff --git a/src/hive/husk.rs b/src/hive/husk.rs index 9795416..e6ce12c 100644 --- a/src/hive/husk.rs +++ b/src/hive/husk.rs @@ -1,6 +1,6 @@ use super::{ - Builder, Config, DerefOutcomes, Hive, Outcome, OutcomeBatch, OutcomeSender, OutcomeStore, - OwnedOutcomes, QueuePair, + Config, DerefOutcomes, Hive, OpenBuilder, Outcome, OutcomeBatch, OutcomeSender, OutcomeStore, + OwnedOutcomes, TaskQueues, }; use crate::bee::{Queen, TaskId, Worker}; use std::collections::HashMap; @@ -10,20 +10,20 @@ use std::ops::{Deref, DerefMut}; /// /// Provides access to the `Queen` and to stored `Outcome`s. Can be used to create a new `Hive` /// based on the previous `Hive`'s configuration. -pub struct Husk> { +pub struct Husk { config: Config, queen: Q, num_panics: usize, - outcomes: HashMap>, + outcomes: HashMap>, } -impl> Husk { +impl Husk { /// Creates a new `Husk`. Should only be called from `Shared::try_into_husk`. pub(super) fn new( config: Config, queen: Q, num_panics: usize, - outcomes: HashMap>, + outcomes: HashMap>, ) -> Self { Self { config, @@ -44,20 +44,23 @@ impl> Husk { } /// Consumes this `Husk` and returns the `Queen` and `Outcome`s. - pub fn into_parts(self) -> (Q, OutcomeBatch) { + pub fn into_parts(self) -> (Q, OutcomeBatch) { (self.queen, OutcomeBatch::new(self.outcomes)) } /// Returns a new `Builder` that will create a `Hive` with the same configuration as the one /// that produced this `Husk`. - pub fn as_builder(&self) -> Builder { - self.config.clone().into() + pub fn as_builder(&self) -> OpenBuilder { + OpenBuilder::from(self.config.clone()) } /// Consumes this `Husk` and returns a new `Hive` with the same configuration and `Queen` as /// the one that produced this `Husk`. - pub fn into_hive>(self) -> Hive { - self.as_builder().build::(self.queen) + pub fn into_hive>(self) -> Hive { + self.as_builder() + .with_queen(self.queen) + .with_queues::() + .build() } /// Consumes this `Husk` and creates a new `Hive` with the same configuration as the one that @@ -65,16 +68,20 @@ impl> Husk { /// be sent to `tx`. Returns the new `Hive` and the IDs of the tasks that were queued. /// /// This method returns a `SpawnError` if there is an error creating the new `Hive`. 
- pub fn into_hive_swarm_send_unprocessed>( + pub fn into_hive_swarm_send_unprocessed>( mut self, - tx: &OutcomeSender, - ) -> (Hive, Vec) { + tx: &OutcomeSender, + ) -> (Hive, Vec) { let unprocessed: Vec<_> = self .remove_all_unprocessed() .into_iter() .map(|(_, input)| input) .collect(); - let hive = self.as_builder().build::(self.queen); + let hive = self + .as_builder() + .with_queen(self.queen) + .with_queues::() + .build(); let task_ids = hive.swarm_send(unprocessed, tx); (hive, task_ids) } @@ -85,40 +92,44 @@ impl> Husk { /// of the tasks that were queued. /// /// This method returns a `SpawnError` if there is an error creating the new `Hive`. - pub fn into_hive_swarm_store_unprocessed>( + pub fn into_hive_swarm_store_unprocessed>( mut self, - ) -> (Hive, Vec) { + ) -> (Hive, Vec) { let unprocessed: Vec<_> = self .remove_all_unprocessed() .into_iter() .map(|(_, input)| input) .collect(); - let hive = self.as_builder().build::(self.queen); + let hive = self + .as_builder() + .with_queen(self.queen) + .with_queues::() + .build(); let task_ids = hive.swarm_store(unprocessed); (hive, task_ids) } } -impl> DerefOutcomes for Husk { +impl> DerefOutcomes for Husk { #[inline] - fn outcomes_deref(&self) -> impl Deref>> { + fn outcomes_deref(&self) -> impl Deref>> { &self.outcomes } #[inline] - fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { + fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { &mut self.outcomes } } -impl> OwnedOutcomes for Husk { +impl> OwnedOutcomes for Husk { #[inline] - fn outcomes(self) -> HashMap> { + fn outcomes(self) -> HashMap> { self.outcomes } #[inline] - fn outcomes_ref(&self) -> &HashMap> { + fn outcomes_ref(&self) -> &HashMap> { &self.outcomes } } @@ -126,15 +137,18 @@ impl> OwnedOutcomes for Husk { #[cfg(test)] mod tests { use crate::bee::stock::{PunkWorker, Thunk, ThunkWorker}; - use crate::hive::queue::ChannelQueues; - use crate::hive::{outcome_channel, Builder, Outcome, OutcomeIteratorExt, OutcomeStore}; + use crate::hive::ChannelTaskQueues; + use crate::hive::{ + outcome_channel, Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, OutcomeStore, + }; #[test] fn test_unprocessed() { // don't spin up any worker threads so that no tasks will be processed - let hive = Builder::new() + let hive = ChannelBuilder::empty() .num_threads(0) - .build_with_default::, ChannelQueues<_>>(); + .with_worker_default::>() + .build(); let mut task_ids = hive.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive.suspend(); @@ -157,14 +171,15 @@ mod tests { #[test] fn test_reprocess_unprocessed() { // don't spin up any worker threads so that no tasks will be processed - let hive1 = Builder::new() + let hive1 = ChannelBuilder::empty() .num_threads(0) - .build_with_default::, ChannelQueues<_>>(); + .with_worker_default::>() + .build(); let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive1.suspend(); let husk1 = hive1.try_into_husk().unwrap(); - let (hive2, _) = husk1.into_hive_swarm_store_unprocessed::>(); + let (hive2, _) = husk1.into_hive_swarm_store_unprocessed::>(); // now spin up worker threads to process the tasks hive2.grow(8).expect("error spawning threads"); hive2.join(); @@ -177,15 +192,16 @@ mod tests { #[test] fn test_reprocess_unprocessed_to() { // don't spin up any worker threads so that no tasks will be processed - let hive1 = Builder::new() + let hive1 = ChannelBuilder::empty() .num_threads(0) - 
.build_with_default::, ChannelQueues<_>>();
+            .with_worker_default::>()
+            .build();
         let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i)));
         // cancel and smash the hive before the tasks can be processed
         hive1.suspend();
         let husk1 = hive1.try_into_husk().unwrap();
         let (tx, rx) = outcome_channel();
-        let (hive2, task_ids) = husk1.into_hive_swarm_send_unprocessed::>(&tx);
+        let (hive2, task_ids) = husk1.into_hive_swarm_send_unprocessed::>(&tx);
         // now spin up worker threads to process the tasks
         hive2.grow(8).expect("error spawning threads");
         hive2.join();
@@ -201,9 +217,10 @@ mod tests {
     #[test]
     fn test_into_result() {
-        let hive = Builder::new()
+        let hive = ChannelBuilder::empty()
             .num_threads(4)
-            .build_with_default::, ChannelQueues<_>>();
+            .with_worker_default::>()
+            .build();
         hive.map_store((0..10).map(|i| Thunk::of(move || i)));
         hive.join();
         let mut outputs = hive.try_into_husk().unwrap().into_parts().1.unwrap();
@@ -214,9 +231,10 @@ mod tests {
     #[test]
     #[should_panic]
     fn test_into_result_panic() {
-        let hive = Builder::new()
+        let hive = ChannelBuilder::empty()
             .num_threads(4)
-            .build_with_default::, ChannelQueues<_>>();
+            .with_worker_default::>()
+            .build();
         hive.map_store(
             (0..10).map(|i| Thunk::of(move || if i == 5 { panic!("oh no!") } else { i })),
         );
diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs
new file mode 100644
index 0000000..024837e
--- /dev/null
+++ b/src/hive/inner/builder.rs
@@ -0,0 +1,331 @@
+use super::{Config, Token};
+
+/// Private (sealed) trait depended on by `Builder` that must be implemented by builder types.
+pub trait BuilderConfig {
+    /// Returns a mutable reference to the underlying `Config`.
+    fn config(&mut self, token: Token) -> &mut Config;
+}
+
+/// Trait that provides `Builder` types with methods for setting configuration parameters.
+///
+/// This is a sealed trait, meaning it cannot be implemented outside of this crate.
+pub trait Builder: BuilderConfig + Sized {
+    /// Sets the maximum number of worker threads that will be alive at any given moment in the
+    /// built [`Hive`]. If not specified, the built `Hive` will not be initialized with worker
+    /// threads until [`Hive::grow`] is called.
+    ///
+    /// # Examples
+    ///
+    /// No more than eight threads will be alive simultaneously for this hive:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::empty()
+    ///     .num_threads(8)
+    ///     .with_worker_default::<ThunkWorker<()>>()
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for _ in 0..100 {
+    ///     hive.apply_store(Thunk::of(|| {
+    ///         println!("Hello from a worker thread!")
+    ///     }));
+    /// }
+    /// # }
+    /// ```
+    fn num_threads(mut self, num: usize) -> Self {
+        let _ = self.config(Token).num_threads.set(Some(num));
+        self
+    }
+
+    /// Sets the number of worker threads to the global default value.
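+    ///
+    /// A sketch of the interplay with the global default (settable via the
+    /// [`beekeeper::hive::set_*_default`](crate::hive#functions) functions):
+    ///
+    /// ```ignore
+    /// beekeeper::hive::set_num_threads_default(2);
+    /// // the builder now picks up the updated global default
+    /// let builder = OpenBuilder::empty().with_default_num_threads();
+    /// ```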
+    fn with_default_num_threads(mut self) -> Self {
+        let _ = self
+            .config(Token)
+            .num_threads
+            .set(super::config::DEFAULTS.lock().num_threads.get());
+        self
+    }
+
+    /// Specifies that the built [`Hive`] will use all available CPU cores for worker threads.
+    ///
+    /// # Examples
+    ///
+    /// All available threads will be alive simultaneously for this hive:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::empty()
+    ///     .with_thread_per_core()
+    ///     .with_worker_default::<ThunkWorker<()>>()
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for _ in 0..100 {
+    ///     hive.apply_store(Thunk::of(|| {
+    ///         println!("Hello from a worker thread!")
+    ///     }));
+    /// }
+    /// # }
+    /// ```
+    fn with_thread_per_core(mut self) -> Self {
+        let _ = self.config(Token).num_threads.set(Some(num_cpus::get()));
+        self
+    }
+
+    /// Sets the thread name for each of the threads spawned by the built [`Hive`]. If not
+    /// specified, threads spawned by the thread pool will be unnamed.
+    ///
+    /// # Examples
+    ///
+    /// Each thread spawned by this hive will have the name `"foo"`:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    /// use std::thread;
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::default()
+    ///     .thread_name("foo")
+    ///     .with_worker_default::<ThunkWorker<()>>()
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for _ in 0..100 {
+    ///     hive.apply_store(Thunk::of(|| {
+    ///         assert_eq!(thread::current().name(), Some("foo"));
+    ///     }));
+    /// }
+    /// # hive.join();
+    /// # }
+    /// ```
+    fn thread_name>(mut self, name: T) -> Self {
+        let _ = self.config(Token).thread_name.set(Some(name.into()));
+        self
+    }
+
+    /// Sets the stack size (in bytes) for each of the threads spawned by the built [`Hive`].
+    /// If not specified, threads spawned by the hive will have a stack size [as specified in
+    /// the `std::thread` documentation][thread].
+    ///
+    /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
+    ///
+    /// # Examples
+    ///
+    /// Each thread spawned by this hive will have a 4 MB stack:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::default()
+    ///     .thread_stack_size(4_000_000)
+    ///     .with_worker_default::<ThunkWorker<()>>()
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for _ in 0..100 {
+    ///     hive.apply_store(Thunk::of(|| {
+    ///         println!("This thread has a 4 MB stack size!");
+    ///     }));
+    /// }
+    /// # hive.join();
+    /// # }
+    /// ```
+    fn thread_stack_size(mut self, size: usize) -> Self {
+        let _ = self.config(Token).thread_stack_size.set(Some(size));
+        self
+    }
+
+    /// Sets the list of CPU core indices to which threads in the `Hive` should be pinned.
+    ///
+    /// Core indices are integers in the range `0..N`, where `N` is the number of available CPU
+    /// cores as reported by [`num_cpus::get()`]. The mapping between core indices and core IDs
+    /// is platform-specific. All CPU cores on a given system should be equivalent, and thus it
+    /// does not matter which cores are pinned so long as a core is not pinned to multiple
+    /// threads.
+    ///
+    /// Excess core indices (i.e., if `affinity.len() > num_threads`) are ignored. If
+    /// `affinity.len() < num_threads` then the excess threads will not be pinned.
+    ///
+    /// # Examples
+    ///
+    /// Each thread spawned by this hive will be pinned to a core:
+    ///
+    /// ```
+    /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::empty()
+    ///     .num_threads(4)
+    ///     .core_affinity(0..4)
+    ///     .with_worker_default::<ThunkWorker<()>>()
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for _ in 0..100 {
+    ///     hive.apply_store(Thunk::of(|| {
+    ///         println!("This thread is pinned!");
+    ///     }));
+    /// }
+    /// # hive.join();
+    /// # }
+    /// ```
+    #[cfg(feature = "affinity")]
+    fn core_affinity>(mut self, affinity: C) -> Self {
+        let _ = self.config(Token).affinity.set(Some(affinity.into()));
+        self
+    }
+
+    /// Specifies that worker threads should be pinned to all available CPU cores. If
+    /// `num_threads` is greater than the available number of CPU cores, then some threads
+    /// might not be pinned.
+    #[cfg(feature = "affinity")]
+    fn with_default_core_affinity(mut self) -> Self {
+        let _ = self
+            .config(Token)
+            .affinity
+            .set(Some(crate::hive::cores::Cores::all()));
+        self
+    }
+
+    /// Sets the worker thread batch size. If `batch_size` is `0`, batching is disabled, but
+    /// note that the performance may be worse than with the `batching` feature disabled.
+    #[cfg(feature = "batching")]
+    fn batch_size(mut self, batch_size: usize) -> Self {
+        if batch_size == 0 {
+            self.config(Token).batch_size.set(None);
+        } else {
+            self.config(Token).batch_size.set(Some(batch_size));
+        }
+        self
+    }
+
+    /// Sets the worker thread batch size to the global default value.
+    #[cfg(feature = "batching")]
+    fn with_default_batch_size(mut self) -> Self {
+        let _ = self
+            .config(Token)
+            .batch_size
+            .set(super::config::DEFAULTS.lock().batch_size.get());
+        self
+    }
+
+    /// Sets the maximum number of times to retry an
+    /// [`ApplyError::Retryable`](crate::bee::ApplyError::Retryable) error. A worker
+    /// thread will retry a task until it either returns
+    /// [`ApplyError::Fatal`](crate::bee::ApplyError::Fatal) or the maximum number of retries is
+    /// reached. Each time a task is retried, the worker thread will first sleep for
+    /// `retry_factor * (2 ** (attempt - 1))` before attempting the task again. If not
+    /// specified, tasks are retried a default number of times. If set to `0`, tasks will not
+    /// be retried.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use beekeeper::bee::{ApplyError, Context};
+    /// use beekeeper::bee::stock::RetryCaller;
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    /// use std::time;
+    ///
+    /// fn sometimes_fail(
+    ///     i: usize,
+    ///     _: &Context
+    /// ) -> Result> {
+    ///     match i % 3 {
+    ///         0 => Ok("Success".into()),
+    ///         1 => Err(ApplyError::Retryable { input: i, error: "Retryable".into() }),
+    ///         2 => Err(ApplyError::Fatal { input: Some(i), error: "Fatal".into() }),
+    ///         _ => unreachable!(),
+    ///     }
+    /// }
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::default()
+    ///     .max_retries(3)
+    ///     .with_worker(RetryCaller::of(sometimes_fail))
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for i in 0..10 {
+    ///     hive.apply_store(i);
+    /// }
+    /// # hive.join();
+    /// # }
+    /// ```
+    #[cfg(feature = "retry")]
+    fn max_retries(mut self, limit: u32) -> Self {
+        let _ = if limit == 0 {
+            self.config(Token).max_retries.set(None)
+        } else {
+            self.config(Token).max_retries.set(Some(limit))
+        };
+        self
+    }
+
+    /// Sets the exponential back-off factor for retrying tasks. Each time a task is retried,
+    /// the thread will first sleep for `retry_factor * (2 ** (attempt - 1))`. If not
+    /// specified, a default retry factor is used. Set to
+    /// [`Duration::ZERO`](std::time::Duration::ZERO) to disable exponential backoff.
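+    ///
+    /// For instance, with a factor of one second the delays are 1s, 2s, 4s, and so on. A
+    /// sketch of that arithmetic (the crate's internal delay handling may differ):
+    ///
+    /// ```
+    /// use std::time::Duration;
+    ///
+    /// fn retry_delay(retry_factor: Duration, attempt: u32) -> Duration {
+    ///     retry_factor * 2u32.pow(attempt - 1)
+    /// }
+    ///
+    /// assert_eq!(retry_delay(Duration::from_secs(1), 1), Duration::from_secs(1));
+    /// assert_eq!(retry_delay(Duration::from_secs(1), 2), Duration::from_secs(2));
+    /// assert_eq!(retry_delay(Duration::from_secs(1), 3), Duration::from_secs(4));
+    /// ```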
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use beekeeper::bee::{ApplyError, Context};
+    /// use beekeeper::bee::stock::RetryCaller;
+    /// use beekeeper::hive::{Builder, OpenBuilder};
+    /// use std::time;
+    ///
+    /// fn echo_time(i: usize, ctx: &Context) -> Result> {
+    ///     let attempt = ctx.attempt();
+    ///     if attempt == 3 {
+    ///         Ok("Success".into())
+    ///     } else {
+    ///         // the delay between each message should be exponential
+    ///         println!("Task {} attempt {}: {:?}", i, attempt, time::SystemTime::now());
+    ///         Err(ApplyError::Retryable { input: i, error: "Retryable".into() })
+    ///     }
+    /// }
+    ///
+    /// # fn main() {
+    /// let hive = OpenBuilder::default()
+    ///     .max_retries(3)
+    ///     .retry_factor(time::Duration::from_secs(1))
+    ///     .with_worker(RetryCaller::of(echo_time))
+    ///     .with_channel_queues()
+    ///     .build();
+    ///
+    /// for i in 0..10 {
+    ///     hive.apply_store(i);
+    /// }
+    /// # hive.join();
+    /// # }
+    /// ```
+    #[cfg(feature = "retry")]
+    fn retry_factor(mut self, duration: std::time::Duration) -> Self {
+        let _ = if duration == std::time::Duration::ZERO {
+            self.config(Token).retry_factor.set(None)
+        } else {
+            self.config(Token).set_retry_factor_from(duration)
+        };
+        self
+    }
+
+    /// Sets retry parameters to their default values.
+    #[cfg(feature = "retry")]
+    fn with_default_retries(mut self) -> Self {
+        let defaults = super::config::DEFAULTS.lock();
+        let _ = self
+            .config(Token)
+            .max_retries
+            .set(defaults.max_retries.get());
+        let _ = self
+            .config(Token)
+            .retry_factor
+            .set(defaults.retry_factor.get());
+        self
+    }
+
+    /// Disables retrying tasks.
+    #[cfg(feature = "retry")]
+    fn with_no_retries(self) -> Self {
+        self.max_retries(0).retry_factor(std::time::Duration::ZERO)
+    }
+}
+
+impl Builder for B {}
diff --git a/src/hive/config.rs b/src/hive/inner/config.rs
similarity index 84%
rename from src/hive/config.rs
rename to src/hive/inner/config.rs
index 43f2c6f..6db735b 100644
--- a/src/hive/config.rs
+++ b/src/hive/inner/config.rs
@@ -1,7 +1,9 @@
 #[cfg(feature = "batching")]
-pub use batching::set_batch_size_default;
+pub use self::batching::set_batch_size_default;
 #[cfg(feature = "retry")]
-pub use retry::{set_max_retries_default, set_retries_default_disabled, set_retry_factor_default};
+pub use self::retry::{
+    set_max_retries_default, set_retries_default_disabled, set_retry_factor_default,
+};
 
 use super::Config;
 use parking_lot::Mutex;
@@ -9,8 +11,8 @@ use std::sync::LazyLock;
 
 const DEFAULT_NUM_THREADS: usize = 4;
 
-pub(super) static DEFAULTS: LazyLock> = LazyLock::new(|| {
-    let mut config = Config::default();
+pub static DEFAULTS: LazyLock> = LazyLock::new(|| {
+    let mut config = Config::empty();
     config.set_const_defaults();
     Mutex::new(config)
 });
@@ -35,12 +37,19 @@ pub fn reset_defaults() {
 impl Config {
     /// Creates a new `Config` with all values unset.
     pub fn empty() -> Self {
-        Self::default()
-    }
-
-    /// Creates a new `Config` with default values. This simply clones `DEFAULTS`.
- pub fn with_defaults() -> Self { - DEFAULTS.lock().clone() + Self { + num_threads: Default::default(), + thread_name: Default::default(), + thread_stack_size: Default::default(), + #[cfg(feature = "affinity")] + affinity: Default::default(), + #[cfg(feature = "batching")] + batch_size: Default::default(), + #[cfg(feature = "retry")] + max_retries: Default::default(), + #[cfg(feature = "retry")] + retry_factor: Default::default(), + } } /// Resets config values to their pre-configured defaults. @@ -88,6 +97,13 @@ impl Config { } } +impl Default for Config { + /// Creates a new `Config` with default values. This simply clones `DEFAULTS`. + fn default() -> Self { + DEFAULTS.lock().clone() + } +} + #[cfg(test)] pub mod reset { /// Struct that resets the default values when `drop`ped. @@ -111,18 +127,18 @@ mod tests { fn test_set_num_threads_default() { let reset = Reset; super::set_num_threads_default(2); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!(config.num_threads.get(), Some(2)); // Dropping `Reset` should reset the defaults drop(reset); let reset = Reset; super::set_num_threads_default_all(); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!(config.num_threads.get(), Some(num_cpus::get())); drop(reset); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!(config.num_threads.get(), Some(super::DEFAULT_NUM_THREADS)); } } @@ -183,7 +199,7 @@ mod retry { #[cfg(test)] mod tests { use super::Config; - use crate::hive::config::reset::Reset; + use crate::hive::inner::config::reset::Reset; use serial_test::serial; use std::time::Duration; @@ -198,18 +214,18 @@ mod retry { fn test_set_max_retries_default() { let reset = Reset; super::set_max_retries_default(1); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!(config.max_retries.get(), Some(1)); // Dropping `Reset` should reset the defaults drop(reset); let reset = Reset; super::set_retries_default_disabled(); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!(config.max_retries.get(), Some(0)); drop(reset); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!(config.max_retries.get(), Some(super::DEFAULT_MAX_RETRIES)); } @@ -218,14 +234,14 @@ mod retry { fn test_set_retry_factor_default() { let reset = Reset; super::set_retry_factor_default(Duration::from_secs(2)); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!( config.get_retry_factor_duration(), Some(Duration::from_secs(2)) ); // Dropping `Reset` should reset the defaults drop(reset); - let config = Config::with_defaults(); + let config = Config::default(); assert_eq!( config.get_retry_factor_duration(), Some(Duration::from_secs(super::DEFAULT_RETRY_FACTOR_SECS)) diff --git a/src/hive/counter.rs b/src/hive/inner/counter.rs similarity index 100% rename from src/hive/counter.rs rename to src/hive/inner/counter.rs diff --git a/src/hive/gate.rs b/src/hive/inner/gate.rs similarity index 100% rename from src/hive/gate.rs rename to src/hive/inner/gate.rs diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs new file mode 100644 index 0000000..9ce3330 --- /dev/null +++ b/src/hive/inner/mod.rs @@ -0,0 +1,107 @@ +mod builder; +mod config; +mod counter; +mod gate; +mod queue; +mod shared; +mod task; + +pub mod set_config { + #[cfg(feature = "batching")] + pub use super::config::set_batch_size_default; + pub use super::config::{reset_defaults, 
set_num_threads_default, set_num_threads_default_all};
+    #[cfg(feature = "retry")]
+    pub use super::config::{
+        set_max_retries_default, set_retries_default_disabled, set_retry_factor_default,
+    };
+}
+
+pub use self::builder::{Builder, BuilderConfig};
+pub use self::queue::{ChannelTaskQueues, TaskQueues};
+
+use self::counter::DualCounter;
+use self::gate::{Gate, PhasedGate};
+use self::queue::PopTaskError;
+use crate::atomic::{AtomicAny, AtomicBool, AtomicOption, AtomicUsize};
+use crate::bee::{Queen, TaskId, Worker};
+use crate::hive::{OutcomeQueue, OutcomeSender, SpawnError};
+use parking_lot::Mutex;
+use std::thread::JoinHandle;
+
+type Any = AtomicOption>;
+type Usize = AtomicOption;
+#[cfg(feature = "retry")]
+type U32 = AtomicOption;
+#[cfg(feature = "retry")]
+type U64 = AtomicOption;
+
+/// Private, zero-sized struct used to call private methods in public sealed traits.
+pub struct Token;
+
+/// Internal representation of a task to be processed by a `Hive`.
+#[derive(Debug)]
+pub struct Task {
+    id: TaskId,
+    input: W::Input,
+    outcome_tx: Option>,
+    #[cfg(feature = "retry")]
+    attempt: u32,
+}
+
+/// Data shared by all worker threads in a `Hive`.
+pub struct Shared> {
+    /// core configuration parameters
+    config: Config,
+    /// the `Queen` used to create new workers
+    queen: Q,
+    /// global and local task queues used by the `Hive` to send tasks to the worker threads
+    task_queues: T,
+    /// the results of spawning each worker thread
+    spawn_results: Mutex, SpawnError>>>,
+    /// counter of queued and active tasks; allows for 2^48 queued tasks and 2^16 active tasks
+    num_tasks: DualCounter<48>,
+    /// ID that will be assigned to the next task submitted to the `Hive`
+    next_task_id: AtomicUsize,
+    /// number of times a worker has panicked
+    num_panics: AtomicUsize,
+    /// number of `Hive` clones with a reference to this shared data
+    num_referrers: AtomicUsize,
+    /// whether the internal state of the hive is corrupted - if true, this prevents new tasks
+    /// from being processed (new tasks may be queued but they will never be processed);
+    /// currently, this can only happen if the task counter somehow gets corrupted
+    poisoned: AtomicBool,
+    /// whether the hive is suspended - if true, active tasks may complete and new tasks may be
+    /// queued, but new tasks will not be processed
+    suspended: AtomicBool,
+    /// gate used by worker threads to wait until the hive is resumed
+    resume_gate: Gate,
+    /// gate used by client threads to wait until all tasks have completed
+    join_gate: PhasedGate,
+    /// outcomes stored in the hive
+    outcomes: OutcomeQueue,
+}
+
+/// Core configuration parameters that are set by a `Builder`, used in a `Hive`, and preserved in
+/// a `Husk`. Fields are `AtomicOption`s, which enables them to be transitioned back and forth
+/// between thread-safe and non-thread-safe contexts.
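The `DualCounter<48>` field above is worth unpacking: it packs the queued and active counts into one atomic word so a task can move from "queued" to "active" in a single operation. A sketch of the idea with illustrative names (not the crate's actual `DualCounter` API; the real counter also detects underflow, which is the corruption that poisons the hive):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Illustrative sketch: the low `LEFT_BITS` bits count queued tasks and the
/// high bits count active tasks, so both live in one atomic word.
struct PackedCounter<const LEFT_BITS: u32> {
    value: AtomicU64,
}

impl<const LEFT_BITS: u32> PackedCounter<LEFT_BITS> {
    const LEFT_MASK: u64 = (1u64 << LEFT_BITS) - 1;

    fn new() -> Self {
        Self { value: AtomicU64::new(0) }
    }

    /// Increments the queued count by `n`.
    fn increment_left(&self, n: u64) {
        self.value.fetch_add(n, Ordering::SeqCst);
    }

    /// Moves one task from queued to active in a single atomic add: adding
    /// `2^LEFT_BITS - 1` subtracts 1 from the low half and carries 1 into the
    /// high half. This assumes the queued count is non-zero; the crate's real
    /// counter checks for that failure case.
    fn transfer(&self) {
        self.value.fetch_add(Self::LEFT_MASK, Ordering::SeqCst);
    }

    /// Returns the current (queued, active) counts.
    fn get(&self) -> (u64, u64) {
        let v = self.value.load(Ordering::SeqCst);
        (v & Self::LEFT_MASK, v >> LEFT_BITS)
    }
}

fn main() {
    let counter: PackedCounter<48> = PackedCounter::new();
    counter.increment_left(2); // two tasks queued
    counter.transfer(); // one of them becomes active
    assert_eq!(counter.get(), (1, 1));
}
```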
+#[derive(Clone, Debug)] +pub struct Config { + /// Number of worker threads to spawn + num_threads: Usize, + /// Name to give each worker thread + thread_name: Any, + /// Stack size for each worker thread + thread_stack_size: Usize, + /// CPU cores to which worker threads can be pinned + #[cfg(feature = "affinity")] + affinity: Any, + /// Maximum number of tasks for a worker thread to take when receiving from the input channel + #[cfg(feature = "batching")] + batch_size: Usize, + /// Maximum number of retries for a task + #[cfg(feature = "retry")] + max_retries: U32, + /// Multiplier for the retry backoff strategy + #[cfg(feature = "retry")] + retry_factor: U64, +} diff --git a/src/hive/queue/local.rs b/src/hive/inner/queue/channel.rs similarity index 67% rename from src/hive/queue/local.rs rename to src/hive/inner/queue/channel.rs index 1993f01..fdcff43 100644 --- a/src/hive/queue/local.rs +++ b/src/hive/inner/queue/channel.rs @@ -1,86 +1,53 @@ -use std::marker::PhantomData; - +//! Implementation of `TaskQueues` that uses `crossbeam` channels for the global queue (i.e., for +//! sending tasks from the `Hive` to the worker threads) and a default implementation of local +//! queues that depends on which combination of the `retry` and `batching` features are enabled. +use super::{PopTaskError, Task, TaskQueues, Token}; +use crate::atomic::{Atomic, AtomicBool}; use crate::bee::{Queen, Worker}; -use crate::hive::{GlobalQueue, LocalQueues, QueuePair, Shared, Task}; +use crate::hive::inner::Shared; +use crossbeam_channel::RecvTimeoutError; use parking_lot::RwLock; +use std::time::Duration; + +// time to wait in between polling the retry queue and then the task receiver +const RECV_TIMEOUT: Duration = Duration::from_secs(1); -pub struct LocalQueuesImpl> { +/// Type alias for the input task channel sender +type TaskSender = crossbeam_channel::Sender>; +/// Type alias for the input task channel receiver +type TaskReceiver = crossbeam_channel::Receiver>; + +pub struct ChannelTaskQueues { + global_tx: TaskSender, + global_rx: TaskReceiver, + closed: AtomicBool, /// thread-local queues of tasks used when the `batching` feature is enabled #[cfg(feature = "batching")] - batch_queues: RwLock>>>, + local_batch_queues: RwLock>>>, /// thread-local queues used for tasks that are waiting to be retried after a failure #[cfg(feature = "retry")] - retry_queues: RwLock>>>, - /// marker for the global queue type - _global: PhantomData, -} - -#[cfg(feature = "retry")] -impl> LocalQueuesImpl { - #[inline] - fn try_pop_retry(&self, thread_index: usize) -> Option> { - self.retry_queues - .read() - .get(thread_index) - .and_then(|queue| queue.try_pop()) - } + local_retry_queues: RwLock>>>, } -#[cfg(feature = "batching")] -impl> LocalQueuesImpl { - #[inline] - fn try_push_local(&self, task: Task, thread_index: usize) -> Result<(), Task> { - self.batch_queues.read()[thread_index].push(task) - } - - #[inline] - fn try_pop_local_or_refill, P: QueuePair>( - &self, - thread_index: usize, - shared: &Shared, - ) -> Option> { - let local_queue = &self.batch_queues.read()[thread_index]; - // pop from the local queue if it has any tasks - if !local_queue.is_empty() { - return local_queue.pop(); - } - // otherwise pull at least 1 and up to `batch_size + 1` tasks from the input channel - // wait for the next task from the receiver - let first = shared.global_queue.try_pop().and_then(Result::ok); - // if we fail after trying to get one, don't keep trying to fill the queue - if first.is_some() { - let batch_size = 
shared.batch_size(); - // batch size 0 means batching is disabled - if batch_size > 0 { - // otherwise try to take up to `batch_size` tasks from the input channel - // and add them to the local queue, but don't block if the input channel - // is empty - for result in shared - .global_queue - .try_iter() - .take(batch_size) - .map(|task| local_queue.push(task)) - { - if let Err(task) = result { - // for some reason we can't push the task to the local queue; - // this should never happen, but just in case we turn it into an - // unprocessed outcome and stop iterating - shared.abandon_task(task); - break; - } - } - } +impl TaskQueues for ChannelTaskQueues { + fn new(_: Token) -> Self { + let (tx, rx) = crossbeam_channel::unbounded(); + Self { + global_tx: tx, + global_rx: rx, + closed: AtomicBool::default(), + #[cfg(feature = "batching")] + local_batch_queues: Default::default(), + #[cfg(feature = "retry")] + local_retry_queues: Default::default(), } - first } -} -impl> LocalQueues for LocalQueuesImpl { - fn init_for_threads, P: QueuePair>( + fn init_for_threads>( &self, start_index: usize, end_index: usize, - #[allow(unused_variables)] shared: &Shared, + #[allow(unused_variables)] shared: &Shared, ) { #[cfg(feature = "batching")] self.init_batch_queues_for_threads(start_index, end_index, shared); @@ -88,25 +55,24 @@ impl> LocalQueues for LocalQueuesImpl { self.init_retry_queues_for_threads(start_index, end_index); } - #[cfg(feature = "batching")] - fn resize, P: QueuePair>( - &self, - start_index: usize, - end_index: usize, - new_size: usize, - shared: &Shared, - ) { - self.resize_batch_queues(start_index, end_index, new_size, shared); + fn try_push_global(&self, task: Task) -> Result<(), Task> { + if !self.closed.get() { + self.global_tx + .try_send(task) + .map_err(|err| err.into_inner()) + } else { + Err(task) + } } /// Creates a task from `input` and pushes it to the local queue if there is space, /// otherwise attempts to add it to the global queue. Returns the task ID if the push /// succeeds, otherwise returns an error with the input. - fn push, P: QueuePair>( + fn push_local>( &self, task: Task, #[allow(unused_variables)] thread_index: usize, - shared: &Shared, + shared: &Shared, ) { #[cfg(feature = "batching")] let task = match self.try_push_local(task, thread_index) { @@ -119,24 +85,26 @@ impl> LocalQueues for LocalQueuesImpl { /// Returns the next task from the local queue if there are any, otherwise attempts to /// fetch at least 1 and up to `batch_size + 1` tasks from the input channel and puts all /// but the first one into the local queue. 
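The refill behavior described above can be summarized in isolation. A sketch using the same crates (`crossbeam_channel`, `crossbeam_queue`); the free-function shape and the `abandon` callback are simplifications of mine, and the real code blocks on the channel with a timeout rather than using `try_recv`:

```rust
use crossbeam_channel::Receiver;
use crossbeam_queue::ArrayQueue;

/// Sketch of the pop-or-refill pattern: serve from the thread-local queue if
/// possible; otherwise take one task from the shared channel and move up to
/// `batch_size` more into the local queue without blocking.
fn pop_or_refill<T>(
    local: &ArrayQueue<T>,
    global: &Receiver<T>,
    batch_size: usize,
    abandon: impl Fn(T),
) -> Option<T> {
    if let Some(task) = local.pop() {
        return Some(task);
    }
    // `try_recv` keeps this sketch simple; the real code waits with a timeout
    let first = global.try_recv().ok()?;
    // `try_iter` never blocks: it only yields tasks already in the channel
    for task in global.try_iter().take(batch_size) {
        if let Err(task) = local.push(task) {
            // no space left - hand the task back to the caller, which in the
            // real code converts it into an `Outcome::Unprocessed`
            abandon(task);
            break;
        }
    }
    Some(first)
}

fn main() {
    let (tx, rx) = crossbeam_channel::unbounded();
    let local = ArrayQueue::new(2);
    for i in 0..4 {
        tx.send(i).unwrap();
    }
    // takes 0 from the channel and moves 1 and 2 into the local queue
    assert_eq!(pop_or_refill(&local, &rx, 2, |_| {}), Some(0));
    // now served from the local queue
    assert_eq!(pop_or_refill(&local, &rx, 2, |_| {}), Some(1));
}
```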
- fn try_pop, P: QueuePair>( + fn try_pop>( &self, thread_index: usize, - #[allow(unused_variables)] shared: &Shared, - ) -> Option> { + #[allow(unused_variables)] shared: &Shared, + ) -> Result, PopTaskError> { + // try to get a task from the local queues #[cfg(feature = "retry")] if let Some(task) = self.try_pop_retry(thread_index) { - return Some(task); + return Ok(task); } #[cfg(feature = "batching")] if let Some(task) = self.try_pop_local_or_refill(thread_index, shared) { - return Some(task); + return Ok(task); } - None + // fall back to requesting a task from the global queue + self.try_pop_timeout(RECV_TIMEOUT) } - fn drain(&self) -> Vec> { - let mut tasks = Vec::new(); + fn drain(&self, _: Token) -> Vec> { + let mut tasks = Vec::from_iter(self.global_rx.try_iter()); #[cfg(feature = "batching")] { self.drain_batch_queues_into(&mut tasks); @@ -148,63 +116,134 @@ impl> LocalQueues for LocalQueuesImpl { tasks } + #[cfg(feature = "batching")] + fn resize_local>( + &self, + start_index: usize, + end_index: usize, + new_size: usize, + shared: &Shared, + ) { + self.resize_batch_queues(start_index, end_index, new_size, shared); + } + #[cfg(feature = "retry")] - fn retry, P: QueuePair>( + fn retry>( &self, task: Task, thread_index: usize, - shared: &Shared, + shared: &Shared, ) -> Option { self.try_push_retry(task, thread_index, shared) } + + fn close(&self, _: Token) { + self.closed.set(true); + } } -impl> Default for LocalQueuesImpl { - fn default() -> Self { - Self { - #[cfg(feature = "batching")] - batch_queues: Default::default(), - #[cfg(feature = "retry")] - retry_queues: Default::default(), - _global: PhantomData, +impl ChannelTaskQueues { + #[inline] + fn try_pop_timeout(&self, timeout: Duration) -> Result, PopTaskError> { + match self.global_rx.recv_timeout(timeout) { + Ok(task) => Ok(task), + Err(RecvTimeoutError::Disconnected) => Err(PopTaskError::Closed), + Err(RecvTimeoutError::Timeout) if self.closed.get() && self.global_rx.is_empty() => { + Err(PopTaskError::Closed) + } + Err(RecvTimeoutError::Timeout) => Err(PopTaskError::Empty), + } + } +} + +#[cfg(feature = "retry")] +impl ChannelTaskQueues { + #[inline] + fn try_pop_retry(&self, thread_index: usize) -> Option> { + self.local_retry_queues + .read() + .get(thread_index) + .and_then(|queue| queue.try_pop()) + } +} + +#[cfg(feature = "batching")] +impl ChannelTaskQueues { + #[inline] + fn try_push_local(&self, task: Task, thread_index: usize) -> Result<(), Task> { + self.local_batch_queues.read()[thread_index].push(task) + } + + #[inline] + fn try_pop_local_or_refill>( + &self, + thread_index: usize, + shared: &Shared, + ) -> Option> { + let local_queue = &self.local_batch_queues.read()[thread_index]; + // pop from the local queue if it has any tasks + if !local_queue.is_empty() { + return local_queue.pop(); + } + // otherwise pull at least 1 and up to `batch_size + 1` tasks from the input channel + // wait for the next task from the receiver + let first = self.try_pop_timeout(RECV_TIMEOUT).ok(); + // if we fail after trying to get one, don't keep trying to fill the queue + if first.is_some() { + let batch_size = shared.batch_size(); + // batch size 0 means batching is disabled + if batch_size > 0 { + // otherwise try to take up to `batch_size` tasks from the input channel + // and add them to the local queue, but don't block if the input channel + // is empty + for result in self + .global_rx + .try_iter() + .take(batch_size) + .map(|task| local_queue.push(task)) + { + if let Err(task) = result { + // for some reason we 
can't push the task to the local queue; + // this should never happen, but just in case we turn it into an + // unprocessed outcome and stop iterating + shared.abandon_task(task); + break; + } + } + } } + first } } #[cfg(feature = "batching")] mod batching { - use super::LocalQueuesImpl; + use super::{ChannelTaskQueues, Task}; use crate::bee::{Queen, Worker}; - use crate::hive::{GlobalQueue, QueuePair, Shared, Task}; + use crate::hive::inner::Shared; use crossbeam_queue::ArrayQueue; use std::collections::HashSet; use std::time::Duration; - impl> LocalQueuesImpl { - pub(super) fn init_batch_queues_for_threads< - Q: Queen, - P: QueuePair, - >( + impl ChannelTaskQueues { + pub(super) fn init_batch_queues_for_threads>( &self, start_index: usize, end_index: usize, - shared: &Shared, + shared: &Shared, ) { - let mut batch_queues = self.batch_queues.write(); + let mut batch_queues = self.local_batch_queues.write(); assert_eq!(batch_queues.len(), start_index); let queue_size = shared.batch_size().max(1); (start_index..end_index).for_each(|_| batch_queues.push(ArrayQueue::new(queue_size))); } - pub(super) fn resize_batch_queues< - Q: Queen, - P: QueuePair, - >( + pub(super) fn resize_batch_queues>( &self, start_index: usize, end_index: usize, batch_size: usize, - shared: &Shared, + shared: &Shared, ) { // keep track of which queues need to be resized // TODO: this method could cause a hang if one of the worker threads is stuck - we @@ -216,7 +255,7 @@ mod batching { loop { // scope the mutable access to local_queues { - let mut batch_queues = self.batch_queues.write(); + let mut batch_queues = self.local_batch_queues.write(); to_resize.retain(|thread_index| { let queue = if let Some(queue) = batch_queues.get_mut(*thread_index) { queue @@ -251,7 +290,7 @@ mod batching { pub(super) fn drain_batch_queues_into(&self, tasks: &mut Vec>) { let _ = self - .batch_queues + .local_batch_queues .write() .iter_mut() .fold(tasks, |tasks, queue| { @@ -267,29 +306,26 @@ mod batching { #[cfg(feature = "retry")] mod retry { - use super::LocalQueuesImpl; + use super::{ChannelTaskQueues, Task}; use crate::bee::{Queen, Worker}; - use crate::hive::queue::delay::DelayQueue; - use crate::hive::{GlobalQueue, QueuePair, Shared, Task}; + use crate::hive::inner::queue::delay::DelayQueue; + use crate::hive::inner::Shared; use std::time::{Duration, Instant}; - impl> LocalQueuesImpl { + impl ChannelTaskQueues { /// Initializes the retry queues worker threads in the specified range. pub(super) fn init_retry_queues_for_threads(&self, start_index: usize, end_index: usize) { - let mut retry_queues = self.retry_queues.write(); + let mut retry_queues = self.local_retry_queues.write(); assert_eq!(retry_queues.len(), start_index); (start_index..end_index).for_each(|_| retry_queues.push(DelayQueue::default())) } /// Adds a task to the retry queue with a delay based on `attempt`. 
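The `DelayQueue` used by these retry queues is not shown in this hunk. As context, here is a minimal single-threaded sketch of the structure such a retry queue needs; the real one is shared between threads (note the `&self` methods above) and so additionally needs interior mutability:

```rust
use std::collections::BinaryHeap;
use std::time::{Duration, Instant};

struct Delayed<T> {
    until: Instant,
    seq: u64,
    item: T,
}

impl<T> PartialEq for Delayed<T> {
    fn eq(&self, other: &Self) -> bool {
        self.until == other.until && self.seq == other.seq
    }
}
impl<T> Eq for Delayed<T> {}
impl<T> PartialOrd for Delayed<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl<T> Ord for Delayed<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // reversed so the *earliest* deadline sits on top of the max-heap
        other.until.cmp(&self.until).then(other.seq.cmp(&self.seq))
    }
}

/// Sketch of a delay queue: `try_pop` yields an item only once its delay has
/// elapsed, which is how a retried task waits out its back-off.
struct SimpleDelayQueue<T> {
    heap: BinaryHeap<Delayed<T>>,
    next_seq: u64,
}

impl<T> SimpleDelayQueue<T> {
    fn new() -> Self {
        Self { heap: BinaryHeap::new(), next_seq: 0 }
    }

    /// Schedules `item` to become available after `delay`; returns the deadline.
    fn push(&mut self, item: T, delay: Duration) -> Instant {
        let until = Instant::now() + delay;
        self.heap.push(Delayed { until, seq: self.next_seq, item });
        self.next_seq += 1;
        until
    }

    /// Pops the item with the earliest deadline, but only if it is due.
    fn try_pop(&mut self) -> Option<T> {
        if self.heap.peek()?.until <= Instant::now() {
            self.heap.pop().map(|d| d.item)
        } else {
            None
        }
    }
}

fn main() {
    let mut queue = SimpleDelayQueue::new();
    queue.push("retry me", Duration::from_millis(50));
    assert!(queue.try_pop().is_none()); // not due yet
    std::thread::sleep(Duration::from_millis(60));
    assert_eq!(queue.try_pop(), Some("retry me"));
}
```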
- pub(super) fn try_push_retry< - Q: Queen, - P: QueuePair, - >( + pub(super) fn try_push_retry>( &self, task: Task, thread_index: usize, - shared: &Shared, + shared: &Shared, ) -> Option { // compute the delay let delay = shared @@ -307,7 +343,7 @@ mod retry { .unwrap() }) .unwrap_or_default(); - if let Some(queue) = self.retry_queues.read().get(thread_index) { + if let Some(queue) = self.local_retry_queues.read().get(thread_index) { queue.push(task, delay) } else { Err(task) @@ -319,7 +355,7 @@ mod retry { pub(super) fn drain_retry_queues_into(&self, tasks: &mut Vec>) { let _ = self - .retry_queues + .local_retry_queues .write() .iter_mut() .fold(tasks, |tasks, queue| { diff --git a/src/hive/queue/delay.rs b/src/hive/inner/queue/delay.rs similarity index 100% rename from src/hive/queue/delay.rs rename to src/hive/inner/queue/delay.rs diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs new file mode 100644 index 0000000..b0a4f59 --- /dev/null +++ b/src/hive/inner/queue/mod.rs @@ -0,0 +1,96 @@ +mod channel; +#[cfg(feature = "retry")] +mod delay; +//mod workstealing; + +pub use self::channel::ChannelTaskQueues; + +use super::{Shared, Task, Token}; +use crate::bee::{Queen, Worker}; + +/// Errors that may occur when trying to pop tasks from the global queue. +#[derive(thiserror::Error, Debug)] +pub enum PopTaskError { + #[error("Global task queue is empty")] + Empty, + #[error("Global task queue is closed")] + Closed, +} + +/// Trait that encapsulates the global and local task queues used by a `Hive` for managing tasks +/// within and between worker threads. +/// +/// This trait is sealed - it cannot be implemented outside of this crate. +pub trait TaskQueues: Sized + Send + Sync + 'static { + /// Returns a new instance. + fn new(token: Token) -> Self; + + /// Initializes the local queues for the given range of worker thread indices. + fn init_for_threads>( + &self, + start_index: usize, + end_index: usize, + shared: &Shared, + ); + + /// Changes the size of the local queues to `new_size`. + #[cfg(feature = "batching")] + fn resize_local>( + &self, + start_index: usize, + end_index: usize, + new_size: usize, + shared: &Shared, + ); + + /// Tries to add a task to the global queue. + /// + /// Returns an error with the task if the queue is disconnected. + fn try_push_global(&self, task: Task) -> Result<(), Task>; + + /// Attempts to add a task to the local queue if space is available, otherwise adds it to the + /// global queue. + /// + /// If adding to the global queue fails, the task is abandoned (converted to an + /// `Outcome::Unprocessed` and sent to the outcome channel or stored in the hive). + fn push_local>( + &self, + task: Task, + thread_index: usize, + shared: &Shared, + ); + + /// Attempts to remove a task from the local queue for the given worker thread index. If there + /// are no local queues, or if the local queues are empty, falls back to taking a task from the + /// global queue. + /// + /// Returns an error if a task is not available, where each implementation may have a different + /// definition of "available". + /// + /// Also returns an error if the queue is empty or disconnected. + fn try_pop>( + &self, + thread_index: usize, + shared: &Shared, + ) -> Result, PopTaskError>; + + /// Drains all tasks from all global and local queues and returns them as a `Vec`. + fn drain(&self, token: Token) -> Vec>; + + /// Attempts to add `task` to the local retry queue. + /// + /// Returns the earliest `Instant` at which it might be retried. 
If the task could not be added
+    /// to the retry queue (e.g., if the queue is full), the task is abandoned (converted to
+    /// `Outcome::Unprocessed` and sent to the outcome channel or stored in the hive) and this
+    /// method returns `None`.
+    #[cfg(feature = "retry")]
+    fn retry>(
+        &self,
+        task: Task,
+        thread_index: usize,
+        shared: &Shared,
+    ) -> Option;
+
+    /// Closes the task queues so no more tasks may be pushed.
+    fn close(&self, token: Token);
+}
diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs
new file mode 100644
index 0000000..345db3e
--- /dev/null
+++ b/src/hive/inner/queue/workstealing.rs
@@ -0,0 +1,36 @@
+use super::{GlobalTaskQueue, LocalTaskQueues, Task};
+use crate::bee::{Queen, Worker};
+use crate::hive::Shared;
+use crossbeam_deque::{Injector, Stealer, Worker as LocalQueue};
+use std::marker::PhantomData;
+
+struct GlobalQueue {
+    queue: Injector>,
+    _worker: PhantomData,
+}
+
+impl GlobalTaskQueue for GlobalQueue {
+    fn try_push(&self, task: Task) -> Result<(), Task> {
+        self.queue.push(task);
+        Ok(())
+    }
+
+    fn try_pop(&self) -> Result, super::PopTaskError> {
+        todo!()
+    }
+
+    fn try_iter(&self) -> impl Iterator> + '_ {
+        todo!()
+    }
+
+    fn drain(&self) -> Vec> {
+        todo!()
+    }
+
+    fn close(&self) {
+        todo!()
+    }
+}
+
+struct WorkerQueue {}
+
+impl LocalTaskQueues> for WorkerQueue {}
diff --git a/src/hive/shared.rs b/src/hive/inner/shared.rs
similarity index 88%
rename from src/hive/shared.rs
rename to src/hive/inner/shared.rs
index 823bbbd..f421cbf 100644
--- a/src/hive/shared.rs
+++ b/src/hive/inner/shared.rs
@@ -1,30 +1,22 @@
-use super::{
-    Config, GlobalQueue, Husk, LocalQueues, Outcome, OutcomeSender, QueuePair, Shared, SpawnError,
-    Task,
-};
+use super::{Config, PopTaskError, Shared, Task, TaskQueues, Token};
 use crate::atomic::{Atomic, AtomicInt, AtomicUsize};
 use crate::bee::{Queen, TaskId, Worker};
 use crate::channel::SenderExt;
+use crate::hive::{Husk, Outcome, OutcomeSender, SpawnError};
+use parking_lot::MutexGuard;
 use std::collections::HashMap;
 use std::ops::DerefMut;
 use std::thread::{Builder, JoinHandle};
 use std::{fmt, iter};
-impl Shared
-where
-    W: Worker,
-    Q: Queen,
-    P: QueuePair,
-{
+impl, T: TaskQueues> Shared {
-    /// Creates a new `Shared` instance with the given configuration, queen, and task receiver,
-    /// and all other fields set to their default values.
+    /// Creates a new `Shared` instance with the given configuration and queen, and all other
+    /// fields set to their default values.
     pub fn new(config: Config, queen: Q) -> Self {
-        let (global_queue, local_queues) = P::new();
         Shared {
             config,
             queen,
-            global_queue,
-            local_queues,
+            task_queues: T::new(Token),
             spawn_results: Default::default(),
             num_tasks: Default::default(),
             next_task_id: Default::default(),
@@ -50,13 +42,18 @@
         builder
     }
+    /// Returns the configured number of worker threads.
+    pub fn num_threads(&self) -> usize {
+        self.config.num_threads.get_or_default()
+    }
+
     /// Spawns the initial set of `self.config.num_threads` worker threads using the provided
     /// spawning function. Returns the number of worker threads that were successfully started.
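The new `workstealing` module above is still a stub (and remains commented out of `inner/mod.rs`). For orientation, a `TaskQueues` implementation along those lines would likely build on the conventional `crossbeam_deque` pattern, adapted here from that crate's documented example; names and shapes are illustrative only:

```rust
use crossbeam_deque::{Injector, Steal, Stealer, Worker};

/// Sketch: how one worker thread finds its next task in a work-stealing
/// setup - check its own deque, then the global injector, then steal from
/// sibling threads.
fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
    // fast path: pop from this thread's own queue
    local.pop().or_else(|| {
        std::iter::repeat_with(|| {
            // refill the local deque from the injector and pop one task...
            global
                .steal_batch_and_pop(local)
                // ...or steal a task from another thread
                .or_else(|| stealers.iter().map(|s| s.steal()).collect())
        })
        // `Steal::Retry` signals a race with another stealer; try again
        .find(|s| !s.is_retry())
        .and_then(Steal::success)
    })
}

fn main() {
    let global = Injector::new();
    let local = Worker::new_fifo();
    let stealers: Vec<Stealer<i32>> = Vec::new();
    global.push(42);
    assert_eq!(find_task(&local, &global, &stealers), Some(42));
}
```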
pub fn init_threads(&self, f: F) -> usize where F: Fn(usize) -> Result, SpawnError>, { - let num_threads = self.config.num_threads.get_or_default(); + let num_threads = self.num_threads(); if num_threads == 0 { return 0; } @@ -90,7 +87,7 @@ where assert_eq!(spawn_results.len(), start_index); let end_index = start_index + num_threads; // if worker threads need a local queue, initialize them before spawning - self.local_queues + self.task_queues .init_for_threads(start_index, end_index, self); // spawn the worker threads and return the results let results: Vec<_> = (start_index..end_index).map(f).collect(); @@ -143,6 +140,10 @@ where .count() } + pub fn spawn_results(&self) -> MutexGuard, SpawnError>>> { + self.spawn_results.lock() + } + /// Returns a new `Worker` from the queen, or an error if a `Worker` could not be created. pub fn create_worker(&self) -> Q::Kind { self.queen.create() @@ -167,7 +168,7 @@ where if let Some(abandoned_task) = if self.is_poisoned() { Some(task) } else { - self.global_queue.try_push(task).err() + self.task_queues.try_push_global(task).err() } { self.abandon_task(abandoned_task); } @@ -180,7 +181,7 @@ where input: W::Input, outcome_tx: Option<&OutcomeSender>, ) -> TaskId { - if self.config.num_threads.get_or_default() == 0 { + if self.num_threads() == 0 { dbg!("WARNING: no worker threads are active for hive"); } let task = self.prepare_task(input, outcome_tx); @@ -200,22 +201,22 @@ where ) -> TaskId { let task = self.prepare_task(input, outcome_tx); let task_id = task.id(); - self.local_queues.push(task, thread_index, self); + self.task_queues.push_local(task, thread_index, self); task_id } /// Creates a new `Task` for each input in the given batch and sends them to the global queue. - pub fn send_batch_global( + pub fn send_batch_global( &self, - inputs: T, + inputs: I, outcome_tx: Option<&OutcomeSender>, ) -> Vec where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, { #[cfg(debug_assertions)] - if self.config.num_threads.get_or_default() == 0 { + if self.num_threads() == 0 { dbg!("WARNING: no worker threads are active for hive"); } let iter = inputs.into_iter(); @@ -248,7 +249,7 @@ where // try to send the task to the hive; if sending fails, convert the task into an // `Unprocessed` outcome and try to send it to the outcome channel; if that // fails, store the outcome in the hive - if let Err(task) = self.global_queue.try_push(task) { + if let Err(task) = self.task_queues.try_push_global(task) { self.abandon_task(task); } task_id @@ -277,18 +278,16 @@ where if self.is_poisoned() { return None; } - // try to get a task from the local queues - if let Some(task) = self.local_queues.try_pop(thread_index, self) { - break Ok(task); - } - // fall back to requesting a task from the global queue - if let Some(result) = self.global_queue.try_pop() { - break result; + // get the next task from the queue - break if its closed + match self.task_queues.try_pop(thread_index, self) { + Ok(task) => break Some(task), + Err(PopTaskError::Closed) => break None, + Err(PopTaskError::Empty) => continue, } } // if a task was successfully received, decrement the queued counter and increment the // active counter - .map(|task| match self.num_tasks.transfer(1) { + .and_then(|task| match self.num_tasks.transfer(1) { Ok(_) => Some(task), Err(_) => { // the hive is in a corrupted state - abandon this task and then poison the hive @@ -298,8 +297,6 @@ where None } }) - .ok() - .flatten() } /// Sends an outcome to `outcome_tx`, or 
stores it in the `Hive` shared data if there is no @@ -370,6 +367,19 @@ where self.no_work_notify_all(); } + /// Returns a reference to the `Queen`. + /// + /// Note that, if the queen is a `QueenMut`, the returned value will be a `QueenCell`, and it + /// is necessary to call its `get()` method to obtain a reference to the inner queen. + pub fn queen(&self) -> &Q { + &self.queen + } + + /// Returns a reference to the `Config`. + pub fn config(&self) -> &Config { + &self.config + } + /// Returns a tuple with the number of (queued, active) tasks. #[inline] pub fn num_tasks(&self) -> (u64, u64) { @@ -399,6 +409,10 @@ where } } + pub fn num_panics(&self) -> usize { + self.num_panics.get() + } + /// Returns the number of `Hive`s holding a reference to this shared data. pub fn num_referrers(&self) -> usize { self.num_referrers.get() @@ -488,15 +502,19 @@ where /// to send them or (if the task does not have a sender, or if the send fails) stores them /// in the `outcomes` map. fn drain_tasks_into_unprocessed(&self) { - self.abandon_batch(self.global_queue.drain().into_iter()); - self.abandon_batch(self.local_queues.drain().into_iter()); + self.abandon_batch(self.task_queues.drain(Token).into_iter()); + } + + /// Close the tasks queues so no more tasks can be added. + pub fn close(&self) { + self.task_queues.close(Token); } /// Consumes this `Shared` and returns a `Husk` containing the `Queen`, panic count, stored /// outcomes, and all configuration information necessary to create a new `Hive`. Any queued /// tasks are converted into `Outcome::Unprocessed` outcomes and either sent to the task's /// sender or (if there is no sender, or the send fails) stored in the `outcomes` map. - pub fn into_husk(self) -> Husk { + pub fn into_husk(self) -> Husk { self.drain_tasks_into_unprocessed(); Husk::new( self.config.into_unsync(), @@ -507,11 +525,11 @@ where } } -impl fmt::Debug for Shared +impl fmt::Debug for Shared where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let (queued, active) = self.num_tasks(); @@ -526,15 +544,15 @@ where #[cfg(feature = "affinity")] mod affinity { + use super::{Shared, TaskQueues}; use crate::bee::{Queen, Worker}; use crate::hive::cores::{Core, Cores}; - use crate::hive::{QueuePair, Shared}; - impl Shared + impl Shared where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { /// Adds cores to which worker threads may be pinned. pub fn add_core_affinity(&self, new_cores: Cores) { @@ -556,14 +574,14 @@ mod affinity { #[cfg(feature = "batching")] mod batching { + use super::{Shared, TaskQueues}; use crate::bee::{Queen, Worker}; - use crate::hive::{LocalQueues, QueuePair, Shared}; - impl Shared + impl Shared where W: Worker, Q: Queen, - P: QueuePair, + T: TaskQueues, { /// Returns the local queue batch size. 
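Several methods above (`abandon_task`, `drain_tasks_into_unprocessed`, and the outcome-sending path) share one fallback rule: an abandoned task becomes an `Outcome::Unprocessed` that goes to the task's outcome channel if possible, and into the hive's outcome store otherwise. A minimal sketch of that rule with stand-in types (`std::sync::mpsc` and a `Vec` instead of the crate's channel and `OutcomeQueue`):

```rust
use std::sync::mpsc::Sender;

/// Stand-in for the crate's outcome type; only the variant used here is shown.
#[derive(Debug)]
enum Outcome<T> {
    Unprocessed { task_id: usize, input: T },
}

/// Sketch of the fallback: prefer the task's outcome channel, and keep the
/// outcome in the hive's store when there is no channel or the send fails.
fn abandon<T>(
    task_id: usize,
    input: T,
    outcome_tx: Option<&Sender<Outcome<T>>>,
    store: &mut Vec<Outcome<T>>,
) {
    let outcome = Outcome::Unprocessed { task_id, input };
    match outcome_tx {
        Some(tx) => {
            // the send fails if the receiver was dropped; recover the outcome
            // from the error and store it instead
            if let Err(err) = tx.send(outcome) {
                store.push(err.0);
            }
        }
        None => store.push(outcome),
    }
}

fn main() {
    let mut store = Vec::new();
    // no outcome channel: the unprocessed outcome is stored
    abandon(1, "a", None, &mut store);
    // receiver already dropped: the send fails and we fall back to the store
    let (tx, rx) = std::sync::mpsc::channel();
    drop(rx);
    abandon(2, "b", Some(&tx), &mut store);
    assert_eq!(store.len(), 2);
}
```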
pub fn batch_size(&self) -> usize {
@@ -587,11 +605,12 @@
             if prev_batch_size == batch_size {
                 return prev_batch_size;
             }
-            let num_threads = self.config.num_threads.get_or_default();
+            let num_threads = self.num_threads();
             if num_threads == 0 {
                 return prev_batch_size;
             }
-            self.local_queues.resize(0, num_threads, batch_size, self);
+            self.task_queues
+                .resize_local(0, num_threads, batch_size, self);
             prev_batch_size
         }
     }
@@ -599,15 +618,16 @@
 #[cfg(feature = "retry")]
 mod retry {
-    use crate::bee::{Queen, Worker};
-    use crate::hive::{LocalQueues, OutcomeSender, QueuePair, Shared, Task, TaskId};
+    use crate::bee::{Queen, TaskId, Worker};
+    use crate::hive::inner::{Shared, Task, TaskQueues};
+    use crate::hive::OutcomeSender;
     use std::time::Instant;
-    impl Shared
+    impl Shared
     where
         W: Worker,
         Q: Queen,
-        P: QueuePair,
+        T: TaskQueues,
     {
         /// Returns `true` if the hive is configured to retry tasks and the `attempt` field of the
         /// given `ctx` is less than the maximum number of retries.
@@ -633,7 +653,7 @@
                 .increment_left(1)
                 .expect("overflowed queued task counter");
             let task = Task::with_attempt(task_id, input, outcome_tx, attempt);
-            self.local_queues.retry(task, thread_index, self)
+            self.task_queues.retry(task, thread_index, self)
         }
     }
 }
@@ -642,17 +662,14 @@
 mod tests {
     use crate::bee::stock::ThunkWorker;
     use crate::bee::DefaultQueen;
-    use crate::hive::queue::ChannelQueues;
+    use crate::hive::ChannelTaskQueues;
     type VoidThunkWorker = ThunkWorker<()>;
-    type VoidThunkWorkerShared = super::Shared<
-        VoidThunkWorker,
-        DefaultQueen,
-        ChannelQueues,
-    >;
+    type VoidThunkWorkerShared =
+        super::Shared, ChannelTaskQueues>;
     #[test]
     fn test_sync_shared() {
         fn assert_sync() {}
         assert_sync::();
     }
diff --git a/src/hive/task.rs b/src/hive/inner/task.rs
similarity index 96%
rename from src/hive/task.rs
rename to src/hive/inner/task.rs
index b2dd42d..5a2f634 100644
--- a/src/hive/task.rs
+++ b/src/hive/inner/task.rs
@@ -1,5 +1,6 @@
-use super::{Outcome, OutcomeSender, Task};
+use super::Task;
 use crate::bee::{TaskId, Worker};
+use crate::hive::{Outcome, OutcomeSender};
 impl Task {
     /// Returns the ID of this task.
diff --git a/src/hive/mod.rs b/src/hive/mod.rs
index 5a95368..fbd65a4 100644
--- a/src/hive/mod.rs
+++ b/src/hive/mod.rs
@@ -356,261 +356,59 @@
 //! ([`Husk::as_builder`](crate::hive::husk::Husk::as_builder)) or a new `Hive`
 //! ([`Husk::into_hive`](crate::hive::husk::Husk::into_hive)).
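One more note on the retry plumbing shown above (`can_retry`/`send_retry`): after a task runs, the worker decides between completing it, re-queuing it with an incremented attempt counter, and failing it. A simplified sketch of that decision with stand-in types; only the `Retryable`/`Fatal` variant shapes are taken from this patch's doc examples:

```rust
/// Stand-ins mirroring the `ApplyError` variants used in this patch's doc
/// examples; everything else here is simplified.
enum ApplyError<I, E> {
    Retryable { input: I, error: E },
    Fatal { input: Option<I>, error: E },
}

enum Disposition<I, O, E> {
    Done(O),
    Requeue { input: I, next_attempt: u32 },
    Failed(E),
}

/// Sketch of the retry decision: requeue a retryable failure with an
/// incremented attempt counter while `attempt < max_retries`; otherwise
/// report the error.
fn dispose<I, O, E>(
    result: Result<O, ApplyError<I, E>>,
    attempt: u32,
    max_retries: u32,
) -> Disposition<I, O, E> {
    match result {
        Ok(output) => Disposition::Done(output),
        Err(ApplyError::Retryable { input, .. }) if attempt < max_retries => {
            Disposition::Requeue { input, next_attempt: attempt + 1 }
        }
        Err(ApplyError::Retryable { error, .. } | ApplyError::Fatal { error, .. }) => {
            Disposition::Failed(error)
        }
    }
}

fn main() {
    let result: Result<u32, _> = Err(ApplyError::Retryable { input: 7u32, error: "flaky" });
    match dispose(result, 1, 3) {
        Disposition::Requeue { input, next_attempt } => {
            println!("requeue {input} as attempt {next_attempt}")
        }
        _ => unreachable!(),
    }
}
```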
mod builder; -mod config; -mod core; -mod counter; -mod gate; -mod husk; -mod outcome; -mod queue; -//mod scoped; -mod shared; -mod task; -//mod workstealing; - #[cfg(feature = "affinity")] pub mod cores; +#[allow(clippy::module_inception)] +mod hive; +mod husk; +mod inner; +mod outcome; -pub use self::builder::Builder; -#[cfg(feature = "batching")] -pub use self::config::set_batch_size_default; -pub use self::config::{reset_defaults, set_num_threads_default, set_num_threads_default_all}; -#[cfg(feature = "retry")] -pub use self::config::{ - set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, -}; -pub use self::core::Poisoned; +pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder}; +pub use self::hive::{Hive, Poisoned}; pub use self::husk::Husk; -pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeQueue, OutcomeStore}; -pub(crate) use self::queue::{ChannelGlobalQueue, ChannelQueues, DefaultLocalQueues}; +pub use self::inner::{set_config::*, Builder, ChannelTaskQueues}; +pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; + +use self::inner::{Config, Shared, Task, TaskQueues}; +use self::outcome::{DerefOutcomes, OutcomeQueue, OwnedOutcomes}; +use crate::bee::Worker; +use crate::channel::{channel, Receiver, Sender}; +use std::io::Error as SpawnError; /// Sender type for channel used to send task outcomes. -pub type OutcomeSender = crate::channel::Sender>; +pub type OutcomeSender = Sender>; /// Receiver type for channel used to receive task outcomes. -pub type OutcomeReceiver = crate::channel::Receiver>; +pub type OutcomeReceiver = Receiver>; /// Creates a channel (`Sender`, `Receiver`) pair for sending task outcomes from the `Hive` to the /// task submitter. #[inline] pub fn outcome_channel() -> (OutcomeSender, OutcomeReceiver) { - crate::channel::channel() + channel() } pub mod prelude { pub use super::{ - outcome_channel, Builder, Hive, Husk, Outcome, OutcomeBatch, OutcomeIteratorExt, - OutcomeStore, Poisoned, + outcome_channel, Builder, ChannelBuilder, ChannelTaskQueues, Hive, Husk, OpenBuilder, + Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned, }; } -use self::counter::DualCounter; -use self::gate::{Gate, PhasedGate}; -use self::outcome::{DerefOutcomes, OwnedOutcomes}; -use crate::atomic::{AtomicAny, AtomicBool, AtomicOption, AtomicUsize}; -use crate::bee::{Queen, TaskId, Worker}; -use parking_lot::Mutex; -use std::io::Error as SpawnError; -use std::sync::Arc; -use std::thread::JoinHandle; - -type Any = AtomicOption>; -type Usize = AtomicOption; -#[cfg(feature = "retry")] -type U32 = AtomicOption; -#[cfg(feature = "retry")] -type U64 = AtomicOption; - -/// A pool of worker threads that each execute the same function. -/// -/// See the [module documentation](crate::hive) for details. -pub struct Hive< - W: Worker, - Q: Queen, - G: GlobalQueue, - L: LocalQueues, - P: QueuePair, ->(Option>>); - -/// Internal representation of a task to be processed by a `Hive`. -#[derive(Debug)] -struct Task { - id: TaskId, - input: W::Input, - outcome_tx: Option>, - #[cfg(feature = "retry")] - attempt: u32, -} - -/// Core configuration parameters that are set by a `Builder`, used in a `Hive`, and preserved in a -/// `Husk`. Fields are `AtomicOption`s, which enables them to be transitioned back and forth -/// between thread-safe and non-thread-safe contexts. 
-#[derive(Clone, Debug, Default)] -struct Config { - /// Number of worker threads to spawn - num_threads: Usize, - /// Name to give each worker thread - thread_name: Any, - /// Stack size for each worker thread - thread_stack_size: Usize, - /// CPU cores to which worker threads can be pinned - #[cfg(feature = "affinity")] - affinity: Any, - /// Maximum number of tasks for a worker thread to take when receiving from the input channel - #[cfg(feature = "batching")] - batch_size: Usize, - /// Maximum number of retries for a task - #[cfg(feature = "retry")] - max_retries: U32, - /// Multiplier for the retry backoff strategy - #[cfg(feature = "retry")] - retry_factor: U64, -} - -/// Data shared by all worker threads in a `Hive`. -struct Shared, P: QueuePair> { - /// core configuration parameters - config: Config, - /// the `Queen` used to create new workers - queen: Q, - /// global task queue used by the `Hive` to send tasks to the worker threads - global_queue: P::Global, - /// local queues used by worker threads to manage tasks - local_queues: P::Local, - /// The results of spawning each worker - spawn_results: Mutex, SpawnError>>>, - /// allows for 2^48 queued tasks and 2^16 active tasks - num_tasks: DualCounter<48>, - /// ID that will be assigned to the next task submitted to the `Hive` - next_task_id: AtomicUsize, - /// number of times a worker has panicked - num_panics: AtomicUsize, - /// number of `Hive` clones with a reference to this shared data - num_referrers: AtomicUsize, - /// whether the internal state of the hive is corrupted - if true, this prevents new tasks from - /// processed (new tasks may be queued but they will never be processed); currently, this can - /// only happen if the task counter somehow get corrupted - poisoned: AtomicBool, - /// whether the hive is suspended - if true, active tasks may complete and new tasks may be - /// queued, but new tasks will not be processed - suspended: AtomicBool, - /// gate used by worker threads to wait until the hive is resumed - resume_gate: Gate, - /// gate used by client threads to wait until all tasks have completed - join_gate: PhasedGate, - /// outcomes stored in the hive - outcomes: OutcomeQueue, -} - -pub trait QueuePair: Sized + 'static { - type Global: GlobalQueue; - type Local: LocalQueues; - - fn new() -> (Self::Global, Self::Local); -} - -#[derive(thiserror::Error, Debug)] -pub enum GlobalPopError { - #[error("Task queue is closed")] - Closed, - #[error("The hive has been poisoned")] - Poisoned, -} - -/// Trait that provides access to a global queue for receiving tasks. -pub trait GlobalQueue: Sized + Send + Sync + 'static { - /// Tries to add a task to the global queue. - /// - /// Returns an error if the queue is disconnected. - fn try_push(&self, task: Task) -> Result<(), Task>; - - /// Tries to take a task from the global queue. - /// - /// Returns `None` if a task is not available, where each implementation may have a different - /// definition of "available". - /// - /// Returns an error if the queue is disconnected. - fn try_pop(&self) -> Option, GlobalPopError>>; - - /// Returns an iterator that yields tasks from the queue unitl it is empty. - fn try_iter(&self) -> impl Iterator> + '_; - - /// Drains all tasks from the global queue and returns them as an iterator. - fn drain(&self) -> Vec>; - - /// Closes this `GlobalQueue` so no more tasks may be pushed. - fn close(&self); -} - -/// Trait that provides access to thread-specific queues for managing tasks. 
-/// -/// Ideally, these queues would be managed in a global thread-local data structure, but since tasks -/// are `Worker`-specific, each `Hive` must have it's own set of queues stored within the Hive's -/// shared data. -pub trait LocalQueues>: Sized + Send + Sync + 'static { - /// Initializes the local queues for the given range of worker thread indices. - fn init_for_threads, P: QueuePair>( - &self, - start_index: usize, - end_index: usize, - shared: &Shared, - ); - - /// Changes the size of the local queues to `size`. - #[cfg(feature = "batching")] - fn resize, P: QueuePair>( - &self, - start_index: usize, - end_index: usize, - new_size: usize, - shared: &Shared, - ); - - /// Attempts to add a task to the local queue if space is available, otherwise adds it to the - /// global queue. If adding to the global queue fails, the task is abandoned (converted to an - /// `Unprocessed` outcome and sent to the outcome channel or stored in the hive). - fn push, P: QueuePair>( - &self, - task: Task, - thread_index: usize, - shared: &Shared, - ); - - /// Attempts to remove a task from the local queue for the given worker thread index. - /// - /// Returns `None` if there is no task immediately available. - fn try_pop, P: QueuePair>( - &self, - thread_index: usize, - shared: &Shared, - ) -> Option>; - - /// Drains all tasks from all local queues and returns them as an iterator. - fn drain(&self) -> Vec>; - - /// Attempts to add `task` to the local retry queue. Returns the earliest `Instant` at which it - /// might be retried. - #[cfg(feature = "retry")] - fn retry, P: QueuePair>( - &self, - task: Task, - thread_index: usize, - shared: &Shared, - ) -> Option; -} - #[cfg(test)] mod tests { - use super::queue::{ChannelQueues, DefaultLocalQueues}; - use super::{Builder, Hive, Outcome, OutcomeIteratorExt, OutcomeStore}; + use super::{ + Builder, ChannelBuilder, ChannelTaskQueues, Hive, OpenBuilder, Outcome, OutcomeIteratorExt, + OutcomeStore, + }; use crate::barrier::IndexedBarrier; use crate::bee::stock::{Caller, OnceCaller, RefCaller, Thunk, ThunkWorker}; use crate::bee::{ - ApplyError, ApplyRefError, Context, DefaultQueen, QueenCell, QueenMut, RefWorker, - RefWorkerResult, TaskId, Worker, WorkerResult, + ApplyError, ApplyRefError, Context, DefaultQueen, QueenMut, RefWorker, RefWorkerResult, + TaskId, Worker, WorkerResult, }; use crate::channel::{Message, ReceiverExt}; use crate::hive::outcome::DerefOutcomes; - use crate::hive::queue::ChannelGlobalQueue; use std::fmt::Debug; use std::io::{self, BufRead, BufReader, Write}; use std::process::{Child, ChildStdin, ChildStdout, Command, ExitStatus, Stdio}; @@ -626,1517 +424,1512 @@ mod tests { const SHORT_TASK: Duration = Duration::from_secs(2); const LONG_TASK: Duration = Duration::from_secs(5); - type Global = ChannelGlobalQueue; - type Local = DefaultLocalQueues>; - type ThunkHive = Hive< - ThunkWorker, - DefaultQueen>, - Global>, - Local>, - ChannelQueues>, - >; + type TWrk = ThunkWorker; + type THive = Hive>, ChannelTaskQueues>>; /// Convenience function that returns a `Hive` configured with the global defaults, and the /// specified number of workers that execute `Thunk`s, i.e. closures that return `T`. 
- pub fn thunk_hive(num_threads: usize) -> ThunkHive { - Builder::default() + pub fn thunk_hive( + num_threads: usize, + with_defaults: bool, + ) -> THive { + let builder = if with_defaults { + ChannelBuilder::default() + } else { + ChannelBuilder::empty() + }; + builder .num_threads(num_threads) - .build_with_default::<_, ChannelQueues>>() + .with_queen_default() + .build() + } + + pub fn void_thunk_hive(num_threads: usize, with_defaults: bool) -> THive<()> { + thunk_hive(num_threads, with_defaults) + } + + #[test] + fn test_works() { + let hive = thunk_hive(TEST_TASKS, true); + let (tx, rx) = mpsc::channel(); + assert_eq!(hive.max_workers(), TEST_TASKS); + assert_eq!(hive.alive_workers(), TEST_TASKS); + assert!(!hive.has_dead_workers()); + for _ in 0..TEST_TASKS { + let tx = tx.clone(); + hive.apply_store(Thunk::of(move || { + tx.send(1).unwrap(); + })); + } + assert_eq!(rx.iter().take(TEST_TASKS).sum::(), TEST_TASKS); + } + + #[test] + fn test_grow_from_zero() { + let hive = thunk_hive::(0, true); + // check that with 0 threads no tasks are scheduled + let (tx, rx) = super::outcome_channel(); + let _ = hive.apply_send(Thunk::of(|| 0), &tx); + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks().0, 1); + assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); + hive.grow(1).expect("error spawning threads"); + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks().0, 0); + assert!(matches!( + rx.try_recv_msg(), + Message::Received(Outcome::Success { value: 0, .. }) + )); + } + + #[test] + fn test_grow() { + let hive = void_thunk_hive(TEST_TASKS, false); + // queue some long-running tasks + for _ in 0..TEST_TASKS { + hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + } + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); + // increase the number of threads + let new_threads = 4; + let total_threads = new_threads + TEST_TASKS; + hive.grow(new_threads).expect("error spawning threads"); + // queue some more long-running tasks + for _ in 0..new_threads { + hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + } + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks().1, total_threads as u64); + let husk = hive.try_into_husk().unwrap(); + assert_eq!(husk.iter_successes().count(), total_threads); + } + + #[test] + fn test_suspend() { + let hive = void_thunk_hive(TEST_TASKS, false); + // queue some long-running tasks + let total_tasks = 2 * TEST_TASKS; + for _ in 0..total_tasks { + hive.apply_store(Thunk::of(|| thread::sleep(SHORT_TASK))); + } + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, TEST_TASKS as u64)); + hive.suspend(); + // active tasks should finish but no more tasks should be started + thread::sleep(SHORT_TASK); + assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, 0)); + assert_eq!(hive.num_successes(), TEST_TASKS); + hive.resume(); + // new tasks should start + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64)); + thread::sleep(SHORT_TASK); + // all tasks should be completed + assert_eq!(hive.num_tasks(), (0, 0)); + assert_eq!(hive.num_successes(), total_tasks); + } + + #[derive(Debug, Default)] + struct MyRefWorker; + + impl RefWorker for MyRefWorker { + type Input = u8; + type Output = u8; + type Error = (); + + fn apply_ref( + &mut self, + input: &Self::Input, + ctx: &Context, + ) -> RefWorkerResult { + for _ in 0..3 { + thread::sleep(Duration::from_secs(1)); + if ctx.is_cancelled() { + return Err(ApplyRefError::Cancelled); + } + } + Ok(*input) + } + } + + #[test] + fn 
test_suspend_with_cancelled_tasks() {
+        let hive: Hive<_, _> = ChannelBuilder::empty()
+            .num_threads(TEST_TASKS)
+            .with_worker_default::()
+            .build();
+        hive.swarm_store(0..TEST_TASKS as u8);
+        hive.suspend();
+        // wait for tasks to be cancelled
+        thread::sleep(Duration::from_secs(2));
+        hive.resume_store();
+        thread::sleep(Duration::from_secs(1));
+        // unprocessed tasks should be requeued
+        assert_eq!(hive.num_tasks().1, TEST_TASKS as u64);
+        thread::sleep(Duration::from_secs(3));
+        assert_eq!(hive.num_successes(), TEST_TASKS);
+    }
+
+    #[test]
+    fn test_num_tasks_active() {
+        let hive = void_thunk_hive(TEST_TASKS, false);
+        for _ in 0..2 * TEST_TASKS {
+            hive.apply_store(Thunk::of(|| loop {
+                thread::sleep(LONG_TASK)
+            }));
+        }
+        thread::sleep(ONE_SEC);
+        assert_eq!(hive.num_tasks().1, TEST_TASKS as u64);
+        let num_threads = hive.max_workers();
+        assert_eq!(num_threads, TEST_TASKS);
+    }
+
+    #[test]
+    fn test_all_threads() {
+        let hive: THive<()> = ChannelBuilder::empty()
+            .with_queen_default()
+            .with_thread_per_core()
+            .build();
+        let num_threads = num_cpus::get();
+        for _ in 0..num_threads {
+            hive.apply_store(Thunk::of(|| loop {
+                thread::sleep(LONG_TASK)
+            }));
+        }
+        thread::sleep(ONE_SEC);
+        assert_eq!(hive.num_tasks().1, num_threads as u64);
+        assert_eq!(hive.max_workers(), num_threads);
+    }
+
+    #[test]
+    fn test_panic() {
+        let hive = thunk_hive(TEST_TASKS, true);
+        let (tx, _) = super::outcome_channel();
+        // Panic all the existing threads.
+        for _ in 0..TEST_TASKS {
+            hive.apply_send(Thunk::of(|| panic!("intentional panic")), &tx);
+        }
+        hive.join();
+        // Ensure that all of the panics were counted
+        assert_eq!(hive.num_panics(), TEST_TASKS);
+        let husk = hive.try_into_husk().unwrap();
+        assert_eq!(husk.num_panics(), TEST_TASKS);
+    }
+
+    #[test]
+    fn test_catch_panic() {
+        let hive: Hive<_, _> = ChannelBuilder::empty()
+            .with_worker(RefCaller::of(|_: &u8| -> Result {
+                panic!("intentional panic")
+            }))
+            .num_threads(TEST_TASKS)
+            .build();
+        let (tx, rx) = super::outcome_channel();
+        // Panic all the existing threads.
+        for i in 0..TEST_TASKS {
+            hive.apply_send(i as u8, &tx);
+        }
+        hive.join();
+        // Ensure that none of the threads have panicked (the panics were caught)
+        assert_eq!(hive.num_panics(), 0);
+        // Check that all the results are Outcome::Panic
+        for outcome in rx.into_iter().take(TEST_TASKS) {
+            assert!(matches!(outcome, Outcome::Panic { ..
})); + } + } + + #[test] + fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() { + let hive = void_thunk_hive(TEST_TASKS, false); + let waiter = Arc::new(Barrier::new(TEST_TASKS + 1)); + let waiter_count = Arc::new(AtomicUsize::new(0)); + + // panic all the existing threads in a bit + for _ in 0..TEST_TASKS { + let waiter = waiter.clone(); + let waiter_count = waiter_count.clone(); + hive.apply_store(Thunk::of(move || { + waiter_count.fetch_add(1, Ordering::SeqCst); + waiter.wait(); + panic!("intentional panic"); + })); + } + + // queued tasks will not be processed after the hive is dropped, so we need to wait to make + // sure that all tasks have started and are waiting on the barrier + // TODO: find a Barrier implementation with try_wait() semantics + thread::sleep(Duration::from_secs(1)); + assert_eq!(waiter_count.load(Ordering::SeqCst), TEST_TASKS); + + drop(hive); + + // unblock the tasks and allow them to panic + waiter.wait(); + } + + #[test] + fn test_massive_task_creation() { + let test_tasks = 4_200_000; + + let hive = thunk_hive(TEST_TASKS, true); + let b0 = IndexedBarrier::new(TEST_TASKS); + let b1 = IndexedBarrier::new(TEST_TASKS); + + let (tx, rx) = mpsc::channel(); + + for _ in 0..test_tasks { + let tx = tx.clone(); + let (b0, b1) = (b0.clone(), b1.clone()); + + hive.apply_store(Thunk::of(move || { + // Wait until the pool has been filled once. + b0.wait(); + // wait so the pool can be measured + b1.wait(); + assert!(tx.send(1).is_ok()); + })); + } + + b0.wait(); + assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); + b1.wait(); + + assert_eq!(rx.iter().take(test_tasks).sum::(), test_tasks); + hive.join(); + + let atomic_num_tasks_active = hive.num_tasks().1; + assert!( + atomic_num_tasks_active == 0, + "atomic_num_tasks_active: {}", + atomic_num_tasks_active + ); + } + + #[test] + fn test_name() { + let name = "test"; + let hive: THive<()> = ChannelBuilder::empty() + .with_queen_default() + .thread_name(name.to_owned()) + .num_threads(2) + .build(); + let (tx, rx) = mpsc::channel(); + + // initial thread should share the name "test" + for _ in 0..2 { + let tx = tx.clone(); + hive.apply_store(Thunk::of(move || { + let name = thread::current().name().unwrap().to_owned(); + tx.send(name).unwrap(); + })); + } + + // new spawn thread should share the name "test" too. + hive.grow(3).expect("error spawning threads"); + let tx_clone = tx.clone(); + hive.apply_store(Thunk::of(move || { + let name = thread::current().name().unwrap().to_owned(); + tx_clone.send(name).unwrap(); + })); + + for thread_name in rx.iter().take(3) { + assert_eq!(name, thread_name); + } + } + + #[test] + fn test_stack_size() { + let stack_size = 4_000_000; + + let hive: THive = ChannelBuilder::empty() + .with_queen_default() + .num_threads(1) + .thread_stack_size(stack_size) + .build(); + + let actual_stack_size = hive + .apply(Thunk::of(|| { + //println!("This thread has a 4 MB stack size!"); + stacker::remaining_stack().unwrap() + })) + .unwrap() as f64; + + // measured value should be within 1% of actual + assert!(actual_stack_size > (stack_size as f64 * 0.99)); + assert!(actual_stack_size < (stack_size as f64 * 1.01)); + } + + #[test] + fn test_debug() { + let hive = void_thunk_hive(4, true); + let debug = format!("{:?}", hive); + assert_eq!( + debug, + "Hive { task_tx: Sender { .. 
}, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + ); + + let hive: THive = ChannelBuilder::empty() + .with_queen_default() + .thread_name("hello") + .num_threads(4) + .build(); + let debug = format!("{:?}", hive); + assert_eq!( + debug, + "Hive { task_tx: Sender { .. }, shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + ); + + let hive = thunk_hive(4, true); + hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + thread::sleep(ONE_SEC); + let debug = format!("{:?}", hive); + assert_eq!( + debug, + "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" + ); + } + + #[test] + fn test_repeated_join() { + let hive: THive<()> = ChannelBuilder::empty() + .with_queen_default() + .thread_name("repeated join test") + .num_threads(8) + .build(); + + let test_count = Arc::new(AtomicUsize::new(0)); + + for _ in 0..42 { + let test_count = test_count.clone(); + hive.apply_store(Thunk::of(move || { + thread::sleep(SHORT_TASK); + test_count.fetch_add(1, Ordering::Release); + })); + } + + hive.join(); + assert_eq!(42, test_count.load(Ordering::Acquire)); + + for _ in 0..42 { + let test_count = test_count.clone(); + hive.apply_store(Thunk::of(move || { + thread::sleep(SHORT_TASK); + test_count.fetch_add(1, Ordering::Relaxed); + })); + } + hive.join(); + assert_eq!(84, test_count.load(Ordering::Relaxed)); + } + + #[test] + fn test_multi_join() { + // Toggle the following lines to debug the deadlock + // fn error(_s: String) { + // use ::std::io::Write; + // let stderr = ::std::io::stderr(); + // let mut stderr = stderr.lock(); + // stderr + // .write(&_s.as_bytes()) + // .expect("Failed to write to stderr"); + // } + + let hive0: THive<()> = ChannelBuilder::empty() + .with_queen_default() + .thread_name("multi join pool0") + .num_threads(4) + .build(); + let hive1: THive<()> = ChannelBuilder::empty() + .with_queen_default() + .thread_name("multi join pool1") + .num_threads(4) + .build(); + let (tx, rx) = crate::channel::channel(); + + for i in 0..8 { + let hive1_clone = hive1.clone(); + let hive0_clone = hive0.clone(); + let tx = tx.clone(); + hive0.apply_store(Thunk::of(move || { + hive1_clone.apply_store(Thunk::of(move || { + //error(format!("p1: {} -=- {:?}\n", i, hive0_clone)); + hive0_clone.join(); + // ensure that the main thread has a chance to execute + thread::sleep(Duration::from_millis(10)); + //error(format!("p1: send({})\n", i)); + tx.send(i).expect("send failed from hive1_clone to main"); + })); + //error(format!("p0: {}\n", i)); + })); + } + drop(tx); + + // no hive1 task should be completed yet, so the channel should be empty + let before_any_send = rx.try_recv_msg(); + assert!(matches!(before_any_send, Message::ChannelEmpty)); + //error(format!("{:?}\n{:?}\n", hive0, hive1)); + hive0.join(); + //error(format!("pool0.join() complete =-= {:?}", hive1)); + hive1.join(); + //error("pool1.join() complete\n".into()); + assert_eq!(rx.into_iter().sum::(), (0..8).sum()); + } + + #[test] + fn test_empty_hive() { + // Joining an empty hive must return imminently + // TODO: run this in a thread and kill it after a timeout to prevent hanging the tests + let hive = void_thunk_hive(4, true); + hive.join(); + } + + #[test] + fn test_no_fun_or_joy() { + // What happens when you keep adding tasks after a join + + fn sleepy_function() { + thread::sleep(LONG_TASK); + } + + let hive: THive<()> = ChannelBuilder::empty() + .with_queen_default() + 
.thread_name("no fun or joy") + .num_threads(8) + .build(); + + hive.apply_store(Thunk::of(sleepy_function)); + + let p_t = hive.clone(); + thread::spawn(move || { + (0..23) + .inspect(|_| { + p_t.apply_store(Thunk::of(sleepy_function)); + }) + .count(); + }); + + hive.join(); + } + + #[test] + fn test_map() { + let hive = thunk_hive::(2, false); + let outputs: Vec<_> = hive + .map((0..10u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + i + }) + })) + .map(Outcome::unwrap) + .collect(); + assert_eq!(outputs, (0..10).collect::>()) + } + + #[test] + fn test_map_unordered() { + let hive = thunk_hive::(8, false); + let outputs: Vec<_> = hive + .map_unordered((0..8u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + i + }) + })) + .map(Outcome::unwrap) + .collect(); + assert_eq!(outputs, (0..8).rev().collect::>()) + } + + #[test] + fn test_map_send() { + let hive = thunk_hive::(8, false); + let (tx, rx) = super::outcome_channel(); + let mut task_ids = hive.map_send( + (0..8u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + i + }) + }), + &tx, + ); + let (mut outcome_task_ids, values): (Vec, Vec) = rx + .iter() + .map(|outcome| match outcome { + Outcome::Success { value, task_id } => (task_id, value), + _ => panic!("unexpected error"), + }) + .unzip(); + assert_eq!(values, (0..8).rev().collect::>()); + task_ids.sort(); + outcome_task_ids.sort(); + assert_eq!(task_ids, outcome_task_ids); + } + + #[test] + fn test_map_store() { + let mut hive = thunk_hive::(8, false); + let mut task_ids = hive.map_store((0..8u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + i + }) + })); + hive.join(); + for i in task_ids.iter() { + assert!(hive.outcomes_deref().get(i).unwrap().is_success()); + } + let (mut outcome_task_ids, values): (Vec, Vec) = task_ids + .clone() + .into_iter() + .map(|i| (i, hive.remove_success(i).unwrap())) + .collect(); + assert_eq!(values, (0..8).collect::>()); + task_ids.sort(); + outcome_task_ids.sort(); + assert_eq!(task_ids, outcome_task_ids); + } + + #[test] + fn test_swarm() { + let hive = thunk_hive::(2, false); + let outputs: Vec<_> = hive + .swarm((0..10u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + i + }) + })) + .map(Outcome::unwrap) + .collect(); + assert_eq!(outputs, (0..10).collect::>()) + } + + #[test] + fn test_swarm_unordered() { + let hive = thunk_hive::(8, false); + let outputs: Vec<_> = hive + .swarm_unordered((0..8u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + i + }) + })) + .map(Outcome::unwrap) + .collect(); + assert_eq!(outputs, (0..8).rev().collect::>()) + } + + #[test] + fn test_swarm_send() { + let hive = thunk_hive::(8, false); + let (tx, rx) = super::outcome_channel(); + let mut task_ids = hive.swarm_send( + (0..8u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + i + }) + }), + &tx, + ); + let (mut outcome_task_ids, values): (Vec, Vec) = rx + .iter() + .map(|outcome| match outcome { + Outcome::Success { value, task_id } => (task_id, value), + _ => panic!("unexpected error"), + }) + .unzip(); + assert_eq!(values, (0..8).rev().collect::>()); + task_ids.sort(); + outcome_task_ids.sort(); + assert_eq!(task_ids, outcome_task_ids); + } + + #[test] + fn test_swarm_store() { + let mut hive = thunk_hive::(8, false); + let 
mut task_ids = hive.swarm_store((0..8u8).map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + i + }) + })); + hive.join(); + for i in task_ids.iter() { + assert!(hive.outcomes_deref().get(i).unwrap().is_success()); + } + let (mut outcome_task_ids, values): (Vec, Vec) = task_ids + .clone() + .into_iter() + .map(|i| (i, hive.remove_success(i).unwrap())) + .collect(); + assert_eq!(values, (0..8).collect::>()); + task_ids.sort(); + outcome_task_ids.sort(); + assert_eq!(task_ids, outcome_task_ids); + } + + #[test] + fn test_scan() { + let hive = ChannelBuilder::empty() + .with_worker(Caller::of(|i| i * i)) + .num_threads(4) + .build(); + let (outputs, state) = hive.scan(0..10, 0, |acc, i| { + *acc += i; + *acc + }); + let mut outputs = outputs.unwrap(); + outputs.sort(); + assert_eq!(outputs.len(), 10); + assert_eq!(state, 45); + assert_eq!( + outputs, + (0..10) + .scan(0, |acc, i| { + *acc += i; + Some(*acc) + }) + .map(|i| i * i) + .collect::>() + ); + } + + #[test] + fn test_scan_send() { + let hive = ChannelBuilder::empty() + .with_worker(Caller::of(|i| i * i)) + .num_threads(4) + .build(); + let (tx, rx) = super::outcome_channel(); + let (mut task_ids, state) = hive.scan_send(0..10, &tx, 0, |acc, i| { + *acc += i; + *acc + }); + assert_eq!(task_ids.len(), 10); + assert_eq!(state, 45); + let (mut outcome_task_ids, mut values): (Vec, Vec) = rx + .iter() + .map(|outcome| match outcome { + Outcome::Success { value, task_id } => (task_id, value), + _ => panic!("unexpected error"), + }) + .unzip(); + values.sort(); + assert_eq!( + values, + (0..10) + .scan(0, |acc, i| { + *acc += i; + Some(*acc) + }) + .map(|i| i * i) + .collect::>() + ); + task_ids.sort(); + outcome_task_ids.sort(); + assert_eq!(task_ids, outcome_task_ids); + } + + #[test] + fn test_try_scan_send() { + let hive = ChannelBuilder::empty() + .with_worker(Caller::of(|i| i * i)) + .num_threads(4) + .build(); + let (tx, rx) = super::outcome_channel(); + let (results, state) = hive.try_scan_send(0..10, &tx, 0, |acc, i| { + *acc += i; + Ok::<_, String>(*acc) + }); + let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect(); + assert_eq!(task_ids.len(), 10); + assert_eq!(state, 45); + let (mut outcome_task_ids, mut values): (Vec, Vec) = rx + .iter() + .map(|outcome| match outcome { + Outcome::Success { value, task_id } => (task_id, value), + _ => panic!("unexpected error"), + }) + .unzip(); + values.sort(); + assert_eq!( + values, + (0..10) + .scan(0, |acc, i| { + *acc += i; + Some(*acc) + }) + .map(|i| i * i) + .collect::>() + ); + task_ids.sort(); + outcome_task_ids.sort(); + assert_eq!(task_ids, outcome_task_ids); + } + + #[test] + #[should_panic] + fn test_try_scan_send_fail() { + let hive = ChannelBuilder::empty() + .with_worker(OnceCaller::of(|i: i32| Ok::<_, String>(i * i))) + .num_threads(4) + .build(); + let (tx, _) = super::outcome_channel(); + let _ = hive + .try_scan_send(0..10, &tx, 0, |_, _| Err("fail")) + .0 + .into_iter() + .map(Result::unwrap) + .collect::>(); + } + + #[test] + fn test_scan_store() { + let mut hive = ChannelBuilder::empty() + .with_worker(Caller::of(|i| i * i)) + .num_threads(4) + .build(); + let (mut task_ids, state) = hive.scan_store(0..10, 0, |acc, i| { + *acc += i; + *acc + }); + assert_eq!(task_ids.len(), 10); + assert_eq!(state, 45); + hive.join(); + for i in task_ids.iter() { + assert!(hive.outcomes_deref().get(i).unwrap().is_success()); + } + let (mut outcome_task_ids, values): (Vec, Vec) = task_ids + .clone() + .into_iter() + 
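+            // pair each task ID with its stored success value (`remove_success`
+            // takes the corresponding outcome out of the hive's store)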
.map(|i| (i, hive.remove_success(i).unwrap()))
+            .unzip();
+        assert_eq!(
+            values,
+            (0..10)
+                .scan(0, |acc, i| {
+                    *acc += i;
+                    Some(*acc)
+                })
+                .map(|i| i * i)
+                .collect::<Vec<_>>()
+        );
+        task_ids.sort();
+        outcome_task_ids.sort();
+        assert_eq!(task_ids, outcome_task_ids);
+    }
+
+    #[test]
+    fn test_try_scan_store() {
+        let mut hive = ChannelBuilder::empty()
+            .with_worker(Caller::of(|i| i * i))
+            .num_threads(4)
+            .build();
+        let (results, state) = hive.try_scan_store(0..10, 0, |acc, i| {
+            *acc += i;
+            Ok::<i32, String>(*acc)
+        });
+        let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect();
+        assert_eq!(task_ids.len(), 10);
+        assert_eq!(state, 45);
+        hive.join();
+        for i in task_ids.iter() {
+            assert!(hive.outcomes_deref().get(i).unwrap().is_success());
+        }
+        let (mut outcome_task_ids, values): (Vec<TaskId>, Vec<i32>) = task_ids
+            .clone()
+            .into_iter()
+            .map(|i| (i, hive.remove_success(i).unwrap()))
+            .unzip();
+        assert_eq!(
+            values,
+            (0..10)
+                .scan(0, |acc, i| {
+                    *acc += i;
+                    Some(*acc)
+                })
+                .map(|i| i * i)
+                .collect::<Vec<_>>()
+        );
+        task_ids.sort();
+        outcome_task_ids.sort();
+        assert_eq!(task_ids, outcome_task_ids);
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_try_scan_store_fail() {
+        let hive = ChannelBuilder::empty()
+            .with_worker(OnceCaller::of(|i: i32| Ok::<i32, String>(i * i)))
+            .num_threads(4)
+            .build();
+        let _ = hive
+            .try_scan_store(0..10, 0, |_, _| Err("fail"))
+            .0
+            .into_iter()
+            .map(Result::unwrap)
+            .collect::<Vec<_>>();
+    }
+
+    #[test]
+    fn test_husk() {
+        let hive1 = thunk_hive::<u8>(8, false);
+        let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::of(move || i)));
+        hive1.join();
+        let mut husk1 = hive1.try_into_husk().unwrap();
+        for i in task_ids.iter() {
+            assert!(husk1.outcomes_deref().get(i).unwrap().is_success());
+            assert!(matches!(husk1.get(*i), Some(Outcome::Success { .. })));
+        }
+
+        let builder = husk1.as_builder();
+        let hive2 = builder
+            .num_threads(4)
+            .with_worker_default::<ThunkWorker<u8>>()
+            .with_channel_queues()
+            .build();
+        hive2.map_store((0..8u8).map(|i| {
+            Thunk::of(move || {
+                thread::sleep(Duration::from_millis((8 - i as u64) * 100));
+                i
+            })
+        }));
+        hive2.join();
+        let mut husk2 = hive2.try_into_husk().unwrap();
+
+        let mut outputs1 = husk1
+            .remove_all()
+            .into_iter()
+            .map(Outcome::unwrap)
+            .collect::<Vec<_>>();
+        outputs1.sort();
+        let mut outputs2 = husk2
+            .remove_all()
+            .into_iter()
+            .map(Outcome::unwrap)
+            .collect::<Vec<_>>();
+        outputs2.sort();
+        assert_eq!(outputs1, outputs2);
+
+        let hive3 = husk1.into_hive::<ChannelTaskQueues<_>>();
+        hive3.map_store((0..8u8).map(|i| {
+            Thunk::of(move || {
+                thread::sleep(Duration::from_millis((8 - i as u64) * 100));
+                i
+            })
+        }));
+        hive3.join();
+        let husk3 = hive3.try_into_husk().unwrap();
+        let (_, outcomes3) = husk3.into_parts();
+        let mut outputs3 = outcomes3
+            .into_iter()
+            .map(Outcome::unwrap)
+            .collect::<Vec<_>>();
+        outputs3.sort();
+        assert_eq!(outputs1, outputs3);
+    }
+
+    #[test]
+    fn test_clone() {
+        let hive: THive<()> = ChannelBuilder::empty()
+            .with_worker_default()
+            .thread_name("clone example")
+            .num_threads(2)
+            .build();
+
+        // This batch of tasks will occupy the pool for some time
+        for _ in 0..6 {
+            hive.apply_store(Thunk::of(|| {
+                thread::sleep(SHORT_TASK);
+            }));
+        }
+
+        // The following tasks will be inserted into the pool in a random fashion
+        let t0 = {
+            let hive = hive.clone();
+            thread::spawn(move || {
+                // wait for the first batch of tasks to finish
+                hive.join();
+
+                let (tx, rx) = mpsc::channel();
+                for i in 0..42 {
+                    let tx = tx.clone();
+                    hive.apply_store(Thunk::of(move || {
+                        tx.send(i).expect("channel will be waiting");
+                    }));
+                }
+                drop(tx);
+                rx.iter().sum::<i32>()
+            })
+        };
+        let t1 = {
+            let pool = hive.clone();
+            thread::spawn(move || {
+                // wait for the first batch of tasks to finish
+                pool.join();
+
+                let (tx, rx) = mpsc::channel();
+                for i in 1..12 {
+                    let tx = tx.clone();
+                    pool.apply_store(Thunk::of(move || {
+                        tx.send(i).expect("channel will be waiting");
+                    }));
+                }
+                drop(tx);
+                rx.iter().product::<i32>()
+            })
+        };
+
+        assert_eq!(
+            861,
+            t0.join()
+                .expect("thread 0 will return after calculating additions",)
+        );
+        assert_eq!(
+            39916800,
+            t1.join()
+                .expect("thread 1 will return after calculating multiplications",)
+        );
+    }
+
+    #[test]
+    fn test_send() {
+        fn assert_send<T: Send>() {}
+        assert_send::<THive<()>>();
+    }
+
+    #[test]
+    fn test_cloned_eq() {
+        let a = thunk_hive::<()>(2, true);
+        assert_eq!(a, a.clone());
+    }
+
+    #[test]
+    /// When a thread joins on a pool, it blocks until all tasks have completed. If a second thread
+    /// adds tasks to the pool and then joins before all the tasks have completed, both threads
+    /// will wait for all tasks to complete. However, as soon as all tasks have completed, all
+    /// joining threads are notified, and the first one to wake will exit the join and increment
+    /// the phase of the condvar. Subsequent notified threads will then see that the phase has
+    /// changed and will wake, even if new tasks have been added in the meantime.
+    ///
+    /// In this example, this means the waiting threads will exit the join in groups of four
+    /// because the waiter pool has four worker threads.
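+    ///
+    /// A rough sketch of the phased wait this description implies (illustrative
+    /// only; the names below are assumptions, not this crate's actual internals):
+    ///
+    /// ```text
+    /// let phase = join_phase.load();              // snapshot the phase before waiting
+    /// while num_tasks() > 0 && join_phase.load() == phase {
+    ///     condvar.wait(&mut guard);               // woken on task completion or phase change
+    /// }
+    /// // the first waiter to exit the join increments join_phase
+    /// ```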
+    fn test_join_wavesurfer() {
+        let n_waves = 4;
+        let n_workers = 4;
+        let (tx, rx) = mpsc::channel();
+        let builder = OpenBuilder::empty()
+            .num_threads(n_workers)
+            .thread_name("join wavesurfer")
+            .with_channel_queues();
+        let waiter_hive = builder
+            .clone()
+            .with_worker_default::<ThunkWorker<()>>()
+            .build();
+        let clock_hive = builder.with_worker_default::<ThunkWorker<()>>().build();
+
+        let barrier = Arc::new(Barrier::new(3));
+        let wave_counter = Arc::new(AtomicUsize::new(0));
+        let clock_thread = {
+            let barrier = barrier.clone();
+            let wave_counter = wave_counter.clone();
+            thread::spawn(move || {
+                barrier.wait();
+                for wave_num in 0..n_waves {
+                    let _ = wave_counter.swap(wave_num, Ordering::SeqCst);
+                    thread::sleep(ONE_SEC);
+                }
+            })
+        };
+
+        {
+            let barrier = barrier.clone();
+            clock_hive.apply_store(Thunk::of(move || {
+                barrier.wait();
+                // this sleep is for stabilisation on weaker platforms
+                thread::sleep(Duration::from_millis(100));
+            }));
+        }
+
+        // prepare three waves of tasks (0..=11)
+        for worker in 0..(3 * n_workers) {
+            let tx = tx.clone();
+            let clock_hive = clock_hive.clone();
+            let wave_counter = wave_counter.clone();
+            waiter_hive.apply_store(Thunk::of(move || {
+                let wave_before = wave_counter.load(Ordering::SeqCst);
+                clock_hive.join();
+                // submit tasks for the next wave
+                clock_hive.apply_store(Thunk::of(|| thread::sleep(ONE_SEC)));
+                let wave_after = wave_counter.load(Ordering::SeqCst);
+                tx.send((wave_before, wave_after, worker)).unwrap();
+            }));
+        }
+        barrier.wait();
+
+        clock_hive.join();
+
+        drop(tx);
+        let mut hist = vec![0; n_waves];
+        let mut data = vec![];
+        for (before, after, worker) in rx.iter() {
+            let mut dur = after - before;
+            if dur >= n_waves - 1 {
+                dur = n_waves - 1;
+            }
+            hist[dur] += 1;
+            data.push((before, after, worker));
+        }
+
+        println!("Histogram of wave duration:");
+        for (i, n) in hist.iter().enumerate() {
+            println!(
+                "\t{}: {} {}",
+                i,
+                n,
+                &*(0..*n).fold("".to_owned(), |s, _| s + "*")
+            );
+        }
+
+        for (wave_before, wave_after, worker) in data.iter() {
+            if *worker < n_workers {
+                assert_eq!(wave_before, wave_after);
+            } else {
+                assert!(wave_before < wave_after);
+            }
+        }
+        clock_thread.join().unwrap();
+    }
+
+    // cargo-llvm-cov doesn't yet support doctests in stable, so we need to duplicate them in
+    // unit tests to get coverage
+
+    #[test]
+    fn doctest_lib_2() {
+        // create a hive to process `Thunk`s - no-argument closures with the same return type (`i32`)
+        let hive: THive<i32> = ChannelBuilder::empty()
+            .with_worker_default()
+            .num_threads(4)
+            .thread_name("thunk_hive")
+            .build();
+
+        // return results to your own channel...
+        let (tx, rx) = crate::hive::outcome_channel();
+        let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::of(move || i * i)), &tx);
+        let outputs: Vec<_> = rx.select_unordered_outputs(task_ids).collect();
+        assert_eq!(285, outputs.into_iter().sum());
+
+        // return results as an iterator...
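+        // (`swarm` yields outcomes in input order; `swarm_unordered` would yield
+        // them in completion order instead)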
+        let outputs2: Vec<_> = hive
+            .swarm((0..10).map(|i: i32| Thunk::of(move || i * -i)))
+            .into_outputs()
+            .collect();
+        assert_eq!(-285, outputs2.into_iter().sum());
+    }
+
+    #[test]
+    fn doctest_lib_3() {
+        #[derive(Debug)]
+        struct CatWorker {
+            stdin: ChildStdin,
+            stdout: BufReader<ChildStdout>,
+        }
+
+        impl CatWorker {
+            fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
+                Self {
+                    stdin,
+                    stdout: BufReader::new(stdout),
+                }
+            }
+
+            fn write_char(&mut self, c: u8) -> io::Result<String> {
+                self.stdin.write_all(&[c])?;
+                self.stdin.write_all(b"\n")?;
+                self.stdin.flush()?;
+                let mut s = String::new();
+                self.stdout.read_line(&mut s)?;
+                s.pop(); // exclude newline
+                Ok(s)
+            }
+        }
+
+        impl Worker for CatWorker {
+            type Input = u8;
+            type Output = String;
+            type Error = io::Error;
+
+            fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult<Self> {
+                self.write_char(input).map_err(|error| ApplyError::Fatal {
+                    input: Some(input),
+                    error,
+                })
+            }
+        }
+
+        #[derive(Default)]
+        struct CatQueen {
+            children: Vec<Child>,
+        }
+
+        impl CatQueen {
+            fn wait_for_all(&mut self) -> Vec<io::Result<ExitStatus>> {
+                self.children
+                    .drain(..)
+                    .map(|mut child| child.wait())
+                    .collect()
+            }
+        }
+
+        impl QueenMut for CatQueen {
+            type Kind = CatWorker;
+
+            fn create(&mut self) -> Self::Kind {
+                let mut child = Command::new("cat")
+                    .stdin(Stdio::piped())
+                    .stdout(Stdio::piped())
+                    .stderr(Stdio::inherit())
+                    .spawn()
+                    .unwrap();
+                let stdin = child.stdin.take().unwrap();
+                let stdout = child.stdout.take().unwrap();
+                self.children.push(child);
+                CatWorker::new(stdin, stdout)
+            }
+        }
+
+        impl Drop for CatQueen {
+            fn drop(&mut self) {
+                self.wait_for_all()
+                    .into_iter()
+                    .for_each(|result| match result {
+                        Ok(status) if status.success() => (),
+                        Ok(status) => eprintln!("Child process failed: {}", status),
+                        Err(e) => eprintln!("Error waiting for child process: {}", e),
+                    })
+            }
+        }
+
+        // build the Hive
+        let hive = ChannelBuilder::empty()
+            .with_queen_mut_default::<CatQueen>()
+            .num_threads(4)
+            .build();
+
+        // prepare inputs
+        let inputs: Vec<u8> = (0..8).map(|i| 97 + i).collect();
+
+        // execute tasks and collect outputs
+        let output = hive
+            .swarm(inputs)
+            .into_outputs()
+            .fold(String::new(), |mut a, b| {
+                a.push_str(&b);
+                a
+            })
+            .into_bytes();
+
+        // verify the output - note that `swarm` ensures the outputs are in the same order
+        // as the inputs
+        assert_eq!(output, b"abcdefgh");
+
+        // shut down the hive, use the Queen to wait on child processes, and report errors
+        let mut queen = hive.try_into_husk().unwrap().into_parts().0.into_inner();
+        let (wait_ok, wait_err): (Vec<_>, Vec<_>) =
+            queen.wait_for_all().into_iter().partition(Result::is_ok);
+        if !wait_err.is_empty() {
+            panic!(
+                "Error(s) occurred while waiting for child processes: {:?}",
+                wait_err
+            );
+        }
+        let exec_err_codes: Vec<_> = wait_ok
+            .into_iter()
+            .map(Result::unwrap)
+            .filter(|status| !status.success())
+            .filter_map(|status| status.code())
+            .collect();
+        if !exec_err_codes.is_empty() {
+            panic!(
+                "Child process(es) failed with exit codes: {:?}",
+                exec_err_codes
+            );
+        }
+    }
+}
+
+#[cfg(all(test, feature = "affinity"))]
+mod affinity_tests {
+    use crate::bee::stock::{Thunk, ThunkWorker};
+    use crate::hive::{Builder, ChannelBuilder};
+
+    #[test]
+    fn test_affinity() {
+        let hive = ChannelBuilder::empty()
+            .thread_name("affinity example")
+            .num_threads(2)
+            .core_affinity(0..2)
+            .with_worker_default::<ThunkWorker<()>>()
+            .build();
+
+        hive.map_store((0..10).map(move |i| {
+            Thunk::of(move || {
+                if let Some(affinity) = core_affinity::get_core_ids() {
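+                    // `get_core_ids` reports the core IDs usable by this thread
+                    // (affinity support varies by platform)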
eprintln!("task {} on thread with affinity {:?}", i, affininty); + } + }) + })); + } + + #[test] + fn test_use_all_cores() { + let hive = ChannelBuilder::empty() + .thread_name("affinity example") + .with_thread_per_core() + .with_default_core_affinity() + .with_worker_default::>() + .build(); + + hive.map_store((0..num_cpus::get()).map(move |i| { + Thunk::of(move || { + if let Some(affininty) = core_affinity::get_core_ids() { + eprintln!("task {} on thread with affinity {:?}", i, affininty); + } + }) + })); + } +} + +#[cfg(all(test, feature = "batching"))] +mod batching_tests { + use crate::barrier::IndexedBarrier; + use crate::bee::stock::{Thunk, ThunkWorker}; + use crate::bee::DefaultQueen; + use crate::hive::{ + Builder, ChannelBuilder, ChannelTaskQueues, Hive, OutcomeIteratorExt, OutcomeReceiver, + OutcomeSender, + }; + use std::collections::HashMap; + use std::thread::{self, ThreadId}; + use std::time::Duration; + + fn launch_tasks( + hive: &Hive>, ChannelTaskQueues>>, + num_threads: usize, + num_tasks_per_thread: usize, + barrier: &IndexedBarrier, + tx: &OutcomeSender>, + ) -> Vec { + let total_tasks = num_threads * num_tasks_per_thread; + // send the first `num_threads` tasks widely spaced, so each worker thread only gets one + let init_task_ids: Vec<_> = (0..num_threads) + .map(|_| { + let barrier = barrier.clone(); + let task_id = hive.apply_send( + Thunk::of(move || { + barrier.wait(); + thread::sleep(Duration::from_millis(100)); + thread::current().id() + }), + tx, + ); + thread::sleep(Duration::from_millis(100)); + task_id + }) + .collect(); + // send the rest all at once + let rest_task_ids = hive.map_send( + (num_threads..total_tasks).map(|_| { + Thunk::of(move || { + thread::sleep(Duration::from_millis(1)); + thread::current().id() + }) + }), + tx, + ); + init_task_ids.into_iter().chain(rest_task_ids).collect() + } + + fn count_thread_ids( + rx: OutcomeReceiver>, + task_ids: Vec, + ) -> HashMap { + rx.select_unordered_outputs(task_ids) + .fold(HashMap::new(), |mut counter, id| { + *counter.entry(id).or_insert(0) += 1; + counter + }) + } + + fn run_test( + hive: &Hive>, ChannelTaskQueues>>, + num_threads: usize, + batch_size: usize, + ) { + let tasks_per_thread = batch_size + 2; + let (tx, rx) = crate::hive::outcome_channel(); + // each worker should take `batch_size` tasks for its queue + 1 to work on immediately, + // meaning there should be `batch_size + 1` tasks associated with each thread ID + let barrier = IndexedBarrier::new(num_threads); + let task_ids = launch_tasks(hive, num_threads, tasks_per_thread, &barrier, &tx); + // start the first tasks + barrier.wait(); + // wait for all tasks to complete + hive.join(); + let thread_counts = count_thread_ids(rx, task_ids); + assert_eq!(thread_counts.len(), num_threads); + assert!(thread_counts + .values() + .all(|&count| count == tasks_per_thread)); } - // #[test] - // fn test_works() { - // let hive = thunk_hive(TEST_TASKS); - // let (tx, rx) = mpsc::channel(); - // assert_eq!(hive.max_workers(), TEST_TASKS); - // assert_eq!(hive.alive_workers(), TEST_TASKS); - // assert!(!hive.has_dead_workers()); - // for _ in 0..TEST_TASKS { - // let tx = tx.clone(); - // hive.apply_store(Thunk::of(move || { - // tx.send(1).unwrap(); - // })); - // } - // assert_eq!(rx.iter().take(TEST_TASKS).sum::(), TEST_TASKS); - // } - - // #[test] - // fn test_grow_from_zero() { - // let hive = thunk_hive::(0); - // // check that with 0 threads no tasks are scheduled - // let (tx, rx) = super::outcome_channel(); - // let _ = 
hive.apply_send(Thunk::of(|| 0), &tx); - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks().0, 1); - // assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); - // hive.grow(1).expect("error spawning threads"); - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks().0, 0); - // assert!(matches!( - // rx.try_recv_msg(), - // Message::Received(Outcome::Success { value: 0, .. }) - // )); - // } - - // #[test] - // fn test_grow() { - // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - // // queue some long-running tasks - // for _ in 0..TEST_TASKS { - // hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); - // } - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - // // increase the number of threads - // let new_threads = 4; - // let total_threads = new_threads + TEST_TASKS; - // hive.grow(new_threads).expect("error spawning threads"); - // // queue some more long-running tasks - // for _ in 0..new_threads { - // hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); - // } - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks().1, total_threads as u64); - // let husk = hive.try_into_husk().unwrap(); - // assert_eq!(husk.iter_successes().count(), total_threads); - // } - - // #[test] - // fn test_suspend() { - // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - // // queue some long-running tasks - // let total_tasks = 2 * TEST_TASKS; - // for _ in 0..total_tasks { - // hive.apply_store(Thunk::of(|| thread::sleep(SHORT_TASK))); - // } - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, TEST_TASKS as u64)); - // hive.suspend(); - // // active tasks should finish but no more tasks should be started - // thread::sleep(SHORT_TASK); - // assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, 0)); - // assert_eq!(hive.num_successes(), TEST_TASKS); - // hive.resume(); - // // new tasks should start - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64)); - // thread::sleep(SHORT_TASK); - // // all tasks should be completed - // assert_eq!(hive.num_tasks(), (0, 0)); - // assert_eq!(hive.num_successes(), total_tasks); - // } - - // #[derive(Debug, Default)] - // struct MyRefWorker; - - // impl RefWorker for MyRefWorker { - // type Input = u8; - // type Output = u8; - // type Error = (); - - // fn apply_ref( - // &mut self, - // input: &Self::Input, - // ctx: &Context, - // ) -> RefWorkerResult { - // for _ in 0..3 { - // thread::sleep(Duration::from_secs(1)); - // if ctx.is_cancelled() { - // return Err(ApplyRefError::Cancelled); - // } - // } - // Ok(*input) - // } - // } - - // #[test] - // fn test_suspend_with_cancelled_tasks() { - // let hive = Builder::new() - // .num_threads(TEST_TASKS) - // .build_with_default::, Local<_>>(); - // hive.swarm_store(0..TEST_TASKS as u8); - // hive.suspend(); - // // wait for tasks to be cancelled - // thread::sleep(Duration::from_secs(2)); - // hive.resume_store(); - // thread::sleep(Duration::from_secs(1)); - // // unprocessed tasks should be requeued - // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - // thread::sleep(Duration::from_secs(3)); - // assert_eq!(hive.num_successes(), TEST_TASKS); - // } - - // #[test] - // fn test_num_tasks_active() { - // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - // for _ in 0..2 * TEST_TASKS { - // hive.apply_store(Thunk::of(|| loop { - // thread::sleep(LONG_TASK) - // })); - 
// } - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - // let num_threads = hive.max_workers(); - // assert_eq!(num_threads, TEST_TASKS); - // } - - // #[test] - // fn test_all_threads() { - // let hive = Builder::new() - // .with_thread_per_core() - // .build_with_default::, Global<_>, Local<_>>(); - // let num_threads = num_cpus::get(); - // for _ in 0..num_threads { - // hive.apply_store(Thunk::of(|| loop { - // thread::sleep(LONG_TASK) - // })); - // } - // thread::sleep(ONE_SEC); - // assert_eq!(hive.num_tasks().1, num_threads as u64); - // let num_threads = hive.max_workers(); - // assert_eq!(num_threads, num_threads); - // } - - // #[test] - // fn test_panic() { - // let hive = thunk_hive(TEST_TASKS); - // let (tx, _) = super::outcome_channel(); - // // Panic all the existing threads. - // for _ in 0..TEST_TASKS { - // hive.apply_send(Thunk::of(|| panic!("intentional panic")), &tx); - // } - // hive.join(); - // // Ensure that none of the threads have panicked - // assert_eq!(hive.num_panics(), TEST_TASKS); - // let husk = hive.try_into_husk().unwrap(); - // assert_eq!(husk.num_panics(), TEST_TASKS); - // } - - // #[test] - // fn test_catch_panic() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(TEST_TASKS) - // .build_with(RefCaller::of(|_: &u8| -> Result { - // panic!("intentional panic") - // })); - // let (tx, rx) = super::outcome_channel(); - // // Panic all the existing threads. - // for i in 0..TEST_TASKS { - // hive.apply_send(i as u8, &tx); - // } - // hive.join(); - // // Ensure that none of the threads have panicked - // assert_eq!(hive.num_panics(), 0); - // // Check that all the results are Outcome::Panic - // for outcome in rx.into_iter().take(TEST_TASKS) { - // assert!(matches!(outcome, Outcome::Panic { .. })); - // } - // } - - // #[test] - // fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() { - // let hive: ThunkHive<()> = Builder::new().num_threads(TEST_TASKS).build_with_default(); - // let waiter = Arc::new(Barrier::new(TEST_TASKS + 1)); - // let waiter_count = Arc::new(AtomicUsize::new(0)); - - // // panic all the existing threads in a bit - // for _ in 0..TEST_TASKS { - // let waiter = waiter.clone(); - // let waiter_count = waiter_count.clone(); - // hive.apply_store(Thunk::of(move || { - // waiter_count.fetch_add(1, Ordering::SeqCst); - // waiter.wait(); - // panic!("intentional panic"); - // })); - // } - - // // queued tasks will not be processed after the hive is dropped, so we need to wait to make - // // sure that all tasks have started and are waiting on the barrier - // // TODO: find a Barrier implementation with try_wait() semantics - // thread::sleep(Duration::from_secs(1)); - // assert_eq!(waiter_count.load(Ordering::SeqCst), TEST_TASKS); - - // drop(hive); - - // // unblock the tasks and allow them to panic - // waiter.wait(); - // } - - // #[test] - // fn test_massive_task_creation() { - // let test_tasks = 4_200_000; - - // let hive = thunk_hive(TEST_TASKS); - // let b0 = IndexedBarrier::new(TEST_TASKS); - // let b1 = IndexedBarrier::new(TEST_TASKS); - - // let (tx, rx) = mpsc::channel(); - - // for _ in 0..test_tasks { - // let tx = tx.clone(); - // let (b0, b1) = (b0.clone(), b1.clone()); - - // hive.apply_store(Thunk::of(move || { - // // Wait until the pool has been filled once. 
- // b0.wait(); - // // wait so the pool can be measured - // b1.wait(); - // assert!(tx.send(1).is_ok()); - // })); - // } - - // b0.wait(); - // assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); - // b1.wait(); - - // assert_eq!(rx.iter().take(test_tasks).sum::(), test_tasks); - // hive.join(); - - // let atomic_num_tasks_active = hive.num_tasks().1; - // assert!( - // atomic_num_tasks_active == 0, - // "atomic_num_tasks_active: {}", - // atomic_num_tasks_active - // ); - // } - - // #[test] - // fn test_name() { - // let name = "test"; - // let hive = Builder::new() - // .thread_name(name.to_owned()) - // .num_threads(2) - // .build_with_default::, Global<_>, Local<_>>(); - // let (tx, rx) = mpsc::channel(); - - // // initial thread should share the name "test" - // for _ in 0..2 { - // let tx = tx.clone(); - // hive.apply_store(Thunk::of(move || { - // let name = thread::current().name().unwrap().to_owned(); - // tx.send(name).unwrap(); - // })); - // } - - // // new spawn thread should share the name "test" too. - // hive.grow(3).expect("error spawning threads"); - // let tx_clone = tx.clone(); - // hive.apply_store(Thunk::of(move || { - // let name = thread::current().name().unwrap().to_owned(); - // tx_clone.send(name).unwrap(); - // })); - - // for thread_name in rx.iter().take(3) { - // assert_eq!(name, thread_name); - // } - // } - - // #[test] - // fn test_stack_size() { - // let stack_size = 4_000_000; - - // let hive = Builder::new() - // .num_threads(1) - // .thread_stack_size(stack_size) - // .build_with_default::, Global<_>, Local<_>>(); - - // let actual_stack_size = hive - // .apply(Thunk::of(|| { - // //println!("This thread has a 4 MB stack size!"); - // stacker::remaining_stack().unwrap() - // })) - // .unwrap() as f64; - - // // measured value should be within 1% of actual - // assert!(actual_stack_size > (stack_size as f64 * 0.99)); - // assert!(actual_stack_size < (stack_size as f64 * 1.01)); - // } - - // #[test] - // fn test_debug() { - // let hive = thunk_hive::<()>(4); - // let debug = format!("{:?}", hive); - // assert_eq!( - // debug, - // "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" - // ); - - // let hive = Builder::new() - // .thread_name("hello") - // .num_threads(4) - // .build_with_default::, Global<_>, Local<_>>(); - // let debug = format!("{:?}", hive); - // assert_eq!( - // debug, - // "Hive { task_tx: Sender { .. }, shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" - // ); - - // let hive = thunk_hive(4); - // hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); - // thread::sleep(ONE_SEC); - // let debug = format!("{:?}", hive); - // assert_eq!( - // debug, - // "Hive { task_tx: Sender { .. 
}, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" - // ); - // } - - // #[test] - // fn test_repeated_join() { - // let hive = Builder::new() - // .thread_name("repeated join test") - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let test_count = Arc::new(AtomicUsize::new(0)); - - // for _ in 0..42 { - // let test_count = test_count.clone(); - // hive.apply_store(Thunk::of(move || { - // thread::sleep(SHORT_TASK); - // test_count.fetch_add(1, Ordering::Release); - // })); - // } - - // hive.join(); - // assert_eq!(42, test_count.load(Ordering::Acquire)); - - // for _ in 0..42 { - // let test_count = test_count.clone(); - // hive.apply_store(Thunk::of(move || { - // thread::sleep(SHORT_TASK); - // test_count.fetch_add(1, Ordering::Relaxed); - // })); - // } - // hive.join(); - // assert_eq!(84, test_count.load(Ordering::Relaxed)); - // } - - // #[test] - // fn test_multi_join() { - // // Toggle the following lines to debug the deadlock - // // fn error(_s: String) { - // // use ::std::io::Write; - // // let stderr = ::std::io::stderr(); - // // let mut stderr = stderr.lock(); - // // stderr - // // .write(&_s.as_bytes()) - // // .expect("Failed to write to stderr"); - // // } - - // let hive0 = Builder::new() - // .thread_name("multi join pool0") - // .num_threads(4) - // .build_with_default::, Global<_>, Local<_>>(); - // let hive1 = Builder::new() - // .thread_name("multi join pool1") - // .num_threads(4) - // .build_with_default::, Global<_>, Local<_>>(); - // let (tx, rx) = crate::channel::channel(); - - // for i in 0..8 { - // let hive1_clone = hive1.clone(); - // let hive0_clone = hive0.clone(); - // let tx = tx.clone(); - // hive0.apply_store(Thunk::of(move || { - // hive1_clone.apply_store(Thunk::of(move || { - // //error(format!("p1: {} -=- {:?}\n", i, hive0_clone)); - // hive0_clone.join(); - // // ensure that the main thread has a chance to execute - // thread::sleep(Duration::from_millis(10)); - // //error(format!("p1: send({})\n", i)); - // tx.send(i).expect("send failed from hive1_clone to main"); - // })); - // //error(format!("p0: {}\n", i)); - // })); - // } - // drop(tx); - - // // no hive1 task should be completed yet, so the channel should be empty - // let before_any_send = rx.try_recv_msg(); - // assert!(matches!(before_any_send, Message::ChannelEmpty)); - // //error(format!("{:?}\n{:?}\n", hive0, hive1)); - // hive0.join(); - // //error(format!("pool0.join() complete =-= {:?}", hive1)); - // hive1.join(); - // //error("pool1.join() complete\n".into()); - // assert_eq!(rx.into_iter().sum::(), (0..8).sum()); - // } - - // #[test] - // fn test_empty_hive() { - // // Joining an empty hive must return imminently - // let hive = thunk_hive::<()>(4); - // hive.join(); - // } - - // #[test] - // fn test_no_fun_or_joy() { - // // What happens when you keep adding tasks after a join - - // fn sleepy_function() { - // thread::sleep(LONG_TASK); - // } - - // let hive = Builder::new() - // .thread_name("no fun or joy") - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - - // hive.apply_store(Thunk::of(sleepy_function)); - - // let p_t = hive.clone(); - // thread::spawn(move || { - // (0..23) - // .inspect(|_| { - // p_t.apply_store(Thunk::of(sleepy_function)); - // }) - // .count(); - // }); - - // hive.join(); - // } - - // #[test] - // fn test_map() { - // let hive = Builder::new() - // .num_threads(2) - // .build_with_default::, Global<_>, Local<_>>(); - // let outputs: Vec<_> = 
hive - // .map((0..10u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((10 - i as u64) * 100)); - // i - // }) - // })) - // .map(Outcome::unwrap) - // .collect(); - // assert_eq!(outputs, (0..10).collect::>()) - // } - - // #[test] - // fn test_map_unordered() { - // let hive = Builder::new() - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let outputs: Vec<_> = hive - // .map_unordered((0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // })) - // .map(Outcome::unwrap) - // .collect(); - // assert_eq!(outputs, (0..8).rev().collect::>()) - // } - - // #[test] - // fn test_map_send() { - // let hive = Builder::new() - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let (tx, rx) = super::outcome_channel(); - // let mut task_ids = hive.map_send( - // (0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // }), - // &tx, - // ); - // let (mut outcome_task_ids, values): (Vec, Vec) = rx - // .iter() - // .map(|outcome| match outcome { - // Outcome::Success { value, task_id } => (task_id, value), - // _ => panic!("unexpected error"), - // }) - // .unzip(); - // assert_eq!(values, (0..8).rev().collect::>()); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // fn test_map_store() { - // let mut hive = Builder::new() - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let mut task_ids = hive.map_store((0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // })); - // hive.join(); - // for i in task_ids.iter() { - // assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - // } - // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - // .clone() - // .into_iter() - // .map(|i| (i, hive.remove_success(i).unwrap())) - // .collect(); - // assert_eq!(values, (0..8).collect::>()); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // fn test_swarm() { - // let hive = Builder::new() - // .num_threads(2) - // .build_with_default::, Global<_>, Local<_>>(); - // let outputs: Vec<_> = hive - // .swarm((0..10u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((10 - i as u64) * 100)); - // i - // }) - // })) - // .map(Outcome::unwrap) - // .collect(); - // assert_eq!(outputs, (0..10).collect::>()) - // } - - // #[test] - // fn test_swarm_unordered() { - // let hive = Builder::new() - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let outputs: Vec<_> = hive - // .swarm_unordered((0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // })) - // .map(Outcome::unwrap) - // .collect(); - // assert_eq!(outputs, (0..8).rev().collect::>()) - // } - - // #[test] - // fn test_swarm_send() { - // let hive = Builder::new() - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let (tx, rx) = super::outcome_channel(); - // let mut task_ids = hive.swarm_send( - // (0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // }), - // &tx, - // ); - // let (mut outcome_task_ids, values): (Vec, Vec) = rx - // .iter() - // 
.map(|outcome| match outcome { - // Outcome::Success { value, task_id } => (task_id, value), - // _ => panic!("unexpected error"), - // }) - // .unzip(); - // assert_eq!(values, (0..8).rev().collect::>()); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // fn test_swarm_store() { - // let mut hive = Builder::new() - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let mut task_ids = hive.swarm_store((0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // })); - // hive.join(); - // for i in task_ids.iter() { - // assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - // } - // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - // .clone() - // .into_iter() - // .map(|i| (i, hive.remove_success(i).unwrap())) - // .collect(); - // assert_eq!(values, (0..8).collect::>()); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // fn test_scan() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(4) - // .build_with(Caller::of(|i| i * i)); - // let (outputs, state) = hive.scan(0..10, 0, |acc, i| { - // *acc += i; - // *acc - // }); - // let mut outputs = outputs.unwrap(); - // outputs.sort(); - // assert_eq!(outputs.len(), 10); - // assert_eq!(state, 45); - // assert_eq!( - // outputs, - // (0..10) - // .scan(0, |acc, i| { - // *acc += i; - // Some(*acc) - // }) - // .map(|i| i * i) - // .collect::>() - // ); - // } - - // #[test] - // fn test_scan_send() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(4) - // .build_with(Caller::of(|i| i * i)); - // let (tx, rx) = super::outcome_channel(); - // let (mut task_ids, state) = hive.scan_send(0..10, &tx, 0, |acc, i| { - // *acc += i; - // *acc - // }); - // assert_eq!(task_ids.len(), 10); - // assert_eq!(state, 45); - // let (mut outcome_task_ids, mut values): (Vec, Vec) = rx - // .iter() - // .map(|outcome| match outcome { - // Outcome::Success { value, task_id } => (task_id, value), - // _ => panic!("unexpected error"), - // }) - // .unzip(); - // values.sort(); - // assert_eq!( - // values, - // (0..10) - // .scan(0, |acc, i| { - // *acc += i; - // Some(*acc) - // }) - // .map(|i| i * i) - // .collect::>() - // ); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // fn test_try_scan_send() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(4) - // .build_with(Caller::of(|i| i * i)); - // let (tx, rx) = super::outcome_channel(); - // let (results, state) = hive.try_scan_send(0..10, &tx, 0, |acc, i| { - // *acc += i; - // Ok::<_, String>(*acc) - // }); - // let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect(); - // assert_eq!(task_ids.len(), 10); - // assert_eq!(state, 45); - // let (mut outcome_task_ids, mut values): (Vec, Vec) = rx - // .iter() - // .map(|outcome| match outcome { - // Outcome::Success { value, task_id } => (task_id, value), - // _ => panic!("unexpected error"), - // }) - // .unzip(); - // values.sort(); - // assert_eq!( - // values, - // (0..10) - // .scan(0, |acc, i| { - // *acc += i; - // Some(*acc) - // }) - // .map(|i| i * i) - // .collect::>() - // ); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // 
#[should_panic] - // fn test_try_scan_send_fail() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(4) - // .build_with(OnceCaller::of(|i: i32| Ok::<_, String>(i * i))); - // let (tx, _) = super::outcome_channel(); - // let _ = hive - // .try_scan_send(0..10, &tx, 0, |_, _| Err("fail")) - // .0 - // .into_iter() - // .map(Result::unwrap) - // .collect::>(); - // } - - // #[test] - // fn test_scan_store() { - // let mut hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(4) - // .build_with(Caller::of(|i| i * i)); - // let (mut task_ids, state) = hive.scan_store(0..10, 0, |acc, i| { - // *acc += i; - // *acc - // }); - // assert_eq!(task_ids.len(), 10); - // assert_eq!(state, 45); - // hive.join(); - // for i in task_ids.iter() { - // assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - // } - // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - // .clone() - // .into_iter() - // .map(|i| (i, hive.remove_success(i).unwrap())) - // .unzip(); - // assert_eq!( - // values, - // (0..10) - // .scan(0, |acc, i| { - // *acc += i; - // Some(*acc) - // }) - // .map(|i| i * i) - // .collect::>() - // ); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // fn test_try_scan_store() { - // let mut hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(4) - // .build_with(Caller::of(|i| i * i)); - // let (results, state) = hive.try_scan_store(0..10, 0, |acc, i| { - // *acc += i; - // Ok::(*acc) - // }); - // let mut task_ids: Vec<_> = results.into_iter().map(Result::unwrap).collect(); - // assert_eq!(task_ids.len(), 10); - // assert_eq!(state, 45); - // hive.join(); - // for i in task_ids.iter() { - // assert!(hive.outcomes_deref().get(i).unwrap().is_success()); - // } - // let (mut outcome_task_ids, values): (Vec, Vec) = task_ids - // .clone() - // .into_iter() - // .map(|i| (i, hive.remove_success(i).unwrap())) - // .unzip(); - // assert_eq!( - // values, - // (0..10) - // .scan(0, |acc, i| { - // *acc += i; - // Some(*acc) - // }) - // .map(|i| i * i) - // .collect::>() - // ); - // task_ids.sort(); - // outcome_task_ids.sort(); - // assert_eq!(task_ids, outcome_task_ids); - // } - - // #[test] - // #[should_panic] - // fn test_try_scan_store_fail() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .num_threads(4) - // .build_with(OnceCaller::of(|i: i32| Ok::(i * i))); - // let _ = hive - // .try_scan_store(0..10, 0, |_, _| Err("fail")) - // .0 - // .into_iter() - // .map(Result::unwrap) - // .collect::>(); - // } - - // #[test] - // fn test_husk() { - // let hive1 = Builder::new() - // .num_threads(8) - // .build_with_default::, Global<_>, Local<_>>(); - // let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::of(move || i))); - // hive1.join(); - // let mut husk1 = hive1.try_into_husk().unwrap(); - // for i in task_ids.iter() { - // assert!(husk1.outcomes_deref().get(i).unwrap().is_success()); - // assert!(matches!(husk1.get(*i), Some(Outcome::Success { .. 
}))); - // } - - // let builder = husk1.as_builder(); - // let hive2 = builder - // .num_threads(4) - // .build_with_default::, Global<_>, Local<_>>(); - // hive2.map_store((0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // })); - // hive2.join(); - // let mut husk2 = hive2.try_into_husk().unwrap(); - - // let mut outputs1 = husk1 - // .remove_all() - // .into_iter() - // .map(Outcome::unwrap) - // .collect::>(); - // outputs1.sort(); - // let mut outputs2 = husk2 - // .remove_all() - // .into_iter() - // .map(Outcome::unwrap) - // .collect::>(); - // outputs2.sort(); - // assert_eq!(outputs1, outputs2); - - // let hive3 = husk1.into_hive::, Local<_>>(); - // hive3.map_store((0..8u8).map(|i| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis((8 - i as u64) * 100)); - // i - // }) - // })); - // hive3.join(); - // let husk3 = hive3.try_into_husk().unwrap(); - // let (_, outcomes3) = husk3.into_parts(); - // let mut outputs3 = outcomes3 - // .into_iter() - // .map(Outcome::unwrap) - // .collect::>(); - // outputs3.sort(); - // assert_eq!(outputs1, outputs3); - // } - - // #[test] - // fn test_clone() { - // let hive = Builder::new() - // .thread_name("clone example") - // .num_threads(2) - // .build_with_default::, Global<_>, Local<_>>(); - - // // This batch of tasks will occupy the pool for some time - // for _ in 0..6 { - // hive.apply_store(Thunk::of(|| { - // thread::sleep(SHORT_TASK); - // })); - // } - - // // The following tasks will be inserted into the pool in a random fashion - // let t0 = { - // let hive = hive.clone(); - // thread::spawn(move || { - // // wait for the first batch of tasks to finish - // hive.join(); - - // let (tx, rx) = mpsc::channel(); - // for i in 0..42 { - // let tx = tx.clone(); - // hive.apply_store(Thunk::of(move || { - // tx.send(i).expect("channel will be waiting"); - // })); - // } - // drop(tx); - // rx.iter().sum::() - // }) - // }; - // let t1 = { - // let pool = hive.clone(); - // thread::spawn(move || { - // // wait for the first batch of tasks to finish - // pool.join(); - - // let (tx, rx) = mpsc::channel(); - // for i in 1..12 { - // let tx = tx.clone(); - // pool.apply_store(Thunk::of(move || { - // tx.send(i).expect("channel will be waiting"); - // })); - // } - // drop(tx); - // rx.iter().product::() - // }) - // }; - - // assert_eq!( - // 861, - // t0.join() - // .expect("thread 0 will return after calculating additions",) - // ); - // assert_eq!( - // 39916800, - // t1.join() - // .expect("thread 1 will return after calculating multiplications",) - // ); - // } - - // type VoidThunkWorker = ThunkWorker<()>; - // type VoidThunkWorkerHive = Hive< - // VoidThunkWorker, - // DefaultQueen, - // ChannelGlobalQueue, - // DefaultLocalQueues>, - // >; - - // #[test] - // fn test_send() { - // fn assert_send() {} - // assert_send::(); - // } - - // #[test] - // fn test_cloned_eq() { - // let a = thunk_hive::<()>(2); - // assert_eq!(a, a.clone()); - // } - - // #[test] - // /// When a thread joins on a pool, it blocks until all tasks have completed. If a second thread - // /// adds tasks to the pool and then joins before all the tasks have completed, both threads - // /// will wait for all tasks to complete. However, as soon as all tasks have completed, all - // /// joining threads are notified, and the first one to wake will exit the join and increment - // /// the phase of the condvar. 
Subsequent notified threads will then see that the phase has been - // /// changed and will wake, even if new tasks have been added in the meantime. - // /// - // /// In this example, this means the waiting threads will exit the join in groups of four - // /// because the waiter pool has four processes. - // fn test_join_wavesurfer() { - // let n_waves = 4; - // let n_workers = 4; - // let (tx, rx) = mpsc::channel(); - // let builder = Builder::new() - // .num_threads(n_workers) - // .thread_name("join wavesurfer"); - // let waiter_hive = builder - // .clone() - // .build_with_default::, Global<_>, Local<_>>(); - // let clock_hive = builder.build_with_default::, Global<_>, Local<_>>(); - - // let barrier = Arc::new(Barrier::new(3)); - // let wave_counter = Arc::new(AtomicUsize::new(0)); - // let clock_thread = { - // let barrier = barrier.clone(); - // let wave_counter = wave_counter.clone(); - // thread::spawn(move || { - // barrier.wait(); - // for wave_num in 0..n_waves { - // let _ = wave_counter.swap(wave_num, Ordering::SeqCst); - // thread::sleep(ONE_SEC); - // } - // }) - // }; - - // { - // let barrier = barrier.clone(); - // clock_hive.apply_store(Thunk::of(move || { - // barrier.wait(); - // // this sleep is for stabilisation on weaker platforms - // thread::sleep(Duration::from_millis(100)); - // })); - // } - - // // prepare three waves of tasks (0..=11) - // for worker in 0..(3 * n_workers) { - // let tx = tx.clone(); - // let clock_hive = clock_hive.clone(); - // let wave_counter = wave_counter.clone(); - // waiter_hive.apply_store(Thunk::of(move || { - // let wave_before = wave_counter.load(Ordering::SeqCst); - // clock_hive.join(); - // // submit tasks for the next wave - // clock_hive.apply_store(Thunk::of(|| thread::sleep(ONE_SEC))); - // let wave_after = wave_counter.load(Ordering::SeqCst); - // tx.send((wave_before, wave_after, worker)).unwrap(); - // })); - // } - // barrier.wait(); - - // clock_hive.join(); - - // drop(tx); - // let mut hist = vec![0; n_waves]; - // let mut data = vec![]; - // for (before, after, worker) in rx.iter() { - // let mut dur = after - before; - // if dur >= n_waves - 1 { - // dur = n_waves - 1; - // } - // hist[dur] += 1; - // data.push((before, after, worker)); - // } - - // println!("Histogram of wave duration:"); - // for (i, n) in hist.iter().enumerate() { - // println!( - // "\t{}: {} {}", - // i, - // n, - // &*(0..*n).fold("".to_owned(), |s, _| s + "*") - // ); - // } - - // for (wave_before, wave_after, worker) in data.iter() { - // if *worker < n_workers { - // assert_eq!(wave_before, wave_after); - // } else { - // assert!(wave_before < wave_after); - // } - // } - // clock_thread.join().unwrap(); - // } - - // // cargo-llvm-cov doesn't yet support doctests in stable, so we need to duplicate them in - // // unit tests to get coverage - - // #[test] - // fn doctest_lib_2() { - // // create a hive to process `Thunk`s - no-argument closures with the same return type (`i32`) - // let hive = Builder::new() - // .num_threads(4) - // .thread_name("thunk_hive") - // .build_with_default::, Global<_>, Local<_>>(); - - // // return results to your own channel... - // let (tx, rx) = crate::hive::outcome_channel(); - // let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::of(move || i * i)), &tx); - // let outputs: Vec<_> = rx.select_unordered_outputs(task_ids).collect(); - // assert_eq!(285, outputs.into_iter().sum()); - - // // return results as an iterator... 
- // let outputs2: Vec<_> = hive - // .swarm((0..10).map(|i: i32| Thunk::of(move || i * -i))) - // .into_outputs() - // .collect(); - // assert_eq!(-285, outputs2.into_iter().sum()); - // } - - // #[test] - // fn doctest_lib_3() { - // #[derive(Debug)] - // struct CatWorker { - // stdin: ChildStdin, - // stdout: BufReader, - // } - - // impl CatWorker { - // fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self { - // Self { - // stdin, - // stdout: BufReader::new(stdout), - // } - // } - - // fn write_char(&mut self, c: u8) -> io::Result { - // self.stdin.write_all(&[c])?; - // self.stdin.write_all(b"\n")?; - // self.stdin.flush()?; - // let mut s = String::new(); - // self.stdout.read_line(&mut s)?; - // s.pop(); // exclude newline - // Ok(s) - // } - // } - - // impl Worker for CatWorker { - // type Input = u8; - // type Output = String; - // type Error = io::Error; - - // fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { - // self.write_char(input).map_err(|error| ApplyError::Fatal { - // input: Some(input), - // error, - // }) - // } - // } - - // #[derive(Default)] - // struct CatQueen { - // children: Vec, - // } - - // impl CatQueen { - // fn wait_for_all(&mut self) -> Vec> { - // self.children - // .drain(..) - // .map(|mut child| child.wait()) - // .collect() - // } - // } - - // impl QueenMut for CatQueen { - // type Kind = CatWorker; - - // fn create(&mut self) -> Self::Kind { - // let mut child = Command::new("cat") - // .stdin(Stdio::piped()) - // .stdout(Stdio::piped()) - // .stderr(Stdio::inherit()) - // .spawn() - // .unwrap(); - // let stdin = child.stdin.take().unwrap(); - // let stdout = child.stdout.take().unwrap(); - // self.children.push(child); - // CatWorker::new(stdin, stdout) - // } - // } - - // impl Drop for CatQueen { - // fn drop(&mut self) { - // self.wait_for_all() - // .into_iter() - // .for_each(|result| match result { - // Ok(status) if status.success() => (), - // Ok(status) => eprintln!("Child process failed: {}", status), - // Err(e) => eprintln!("Error waiting for child process: {}", e), - // }) - // } - // } - - // // build the Hive - // let hive = Builder::new() - // .num_threads(4) - // .build_default::, Global<_>, Local<_>>(); - - // // prepare inputs - // let inputs: Vec = (0..8).map(|i| 97 + i).collect(); - - // // execute tasks and collect outputs - // let output = hive - // .swarm(inputs) - // .into_outputs() - // .fold(String::new(), |mut a, b| { - // a.push_str(&b); - // a - // }) - // .into_bytes(); - - // // verify the output - note that `swarm` ensures the outputs are in the same order - // // as the inputs - // assert_eq!(output, b"abcdefgh"); - - // // shutdown the hive, use the Queen to wait on child processes, and report errors - // let mut queen = hive.try_into_husk().unwrap().into_parts().0.into_inner(); - // let (wait_ok, wait_err): (Vec<_>, Vec<_>) = - // queen.wait_for_all().into_iter().partition(Result::is_ok); - // if !wait_err.is_empty() { - // panic!( - // "Error(s) occurred while waiting for child processes: {:?}", - // wait_err - // ); - // } - // let exec_err_codes: Vec<_> = wait_ok - // .into_iter() - // .map(Result::unwrap) - // .filter(|status| !status.success()) - // .filter_map(|status| status.code()) - // .collect(); - // if !exec_err_codes.is_empty() { - // panic!( - // "Child process(es) failed with exit codes: {:?}", - // exec_err_codes - // ); - // } - // } - // } - - // #[cfg(all(test, feature = "affinity"))] - // mod affinity_tests { - // use crate::bee::stock::{Thunk, ThunkWorker}; - // 
use crate::hive::queue::{ChannelGlobalQueue, DefaultLocalQueues}; - // use crate::hive::Builder; - - // type Global = ChannelGlobalQueue; - // type Local = DefaultLocalQueues>; - - // #[test] - // fn test_affinity() { - // let hive = Builder::new() - // .thread_name("affinity example") - // .num_threads(2) - // .core_affinity(0..2) - // .build_with_default::, Global<_>, Local<_>>(); - - // hive.map_store((0..10).map(move |i| { - // Thunk::of(move || { - // if let Some(affininty) = core_affinity::get_core_ids() { - // eprintln!("task {} on thread with affinity {:?}", i, affininty); - // } - // }) - // })); - // } - - // #[test] - // fn test_use_all_cores() { - // let hive = Builder::new() - // .thread_name("affinity example") - // .with_thread_per_core() - // .with_default_core_affinity() - // .build_with_default::, Global<_>, Local<_>>(); - - // hive.map_store((0..num_cpus::get()).map(move |i| { - // Thunk::of(move || { - // if let Some(affininty) = core_affinity::get_core_ids() { - // eprintln!("task {} on thread with affinity {:?}", i, affininty); - // } - // }) - // })); - // } - // } - - // #[cfg(all(test, feature = "batching"))] - // mod batching_tests { - // use crate::barrier::IndexedBarrier; - // use crate::bee::stock::{Thunk, ThunkWorker}; - // use crate::bee::DefaultQueen; - // use crate::hive::queue::{ChannelGlobalQueue, DefaultLocalQueues}; - // use crate::hive::{Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender}; - // use std::collections::HashMap; - // use std::thread::{self, ThreadId}; - // use std::time::Duration; - - // type Global = ChannelGlobalQueue; - // type Local = DefaultLocalQueues>; - - // fn launch_tasks( - // hive: &Hive< - // ThunkWorker, - // DefaultQueen>, - // ChannelGlobalQueue>, - // DefaultLocalQueues, ChannelGlobalQueue>>, - // >, - // num_threads: usize, - // num_tasks_per_thread: usize, - // barrier: &IndexedBarrier, - // tx: &OutcomeSender>, - // ) -> Vec { - // let total_tasks = num_threads * num_tasks_per_thread; - // // send the first `num_threads` tasks widely spaced, so each worker thread only gets one - // let init_task_ids: Vec<_> = (0..num_threads) - // .map(|_| { - // let barrier = barrier.clone(); - // let task_id = hive.apply_send( - // Thunk::of(move || { - // barrier.wait(); - // thread::sleep(Duration::from_millis(100)); - // thread::current().id() - // }), - // tx, - // ); - // thread::sleep(Duration::from_millis(100)); - // task_id - // }) - // .collect(); - // // send the rest all at once - // let rest_task_ids = hive.map_send( - // (num_threads..total_tasks).map(|_| { - // Thunk::of(move || { - // thread::sleep(Duration::from_millis(1)); - // thread::current().id() - // }) - // }), - // tx, - // ); - // init_task_ids.into_iter().chain(rest_task_ids).collect() - // } - - // fn count_thread_ids( - // rx: OutcomeReceiver>, - // task_ids: Vec, - // ) -> HashMap { - // rx.select_unordered_outputs(task_ids) - // .fold(HashMap::new(), |mut counter, id| { - // *counter.entry(id).or_insert(0) += 1; - // counter - // }) - // } - - // fn run_test( - // hive: &Hive< - // ThunkWorker, - // DefaultQueen>, - // ChannelGlobalQueue>, - // DefaultLocalQueues, ChannelGlobalQueue>>, - // >, - // num_threads: usize, - // batch_size: usize, - // ) { - // let tasks_per_thread = batch_size + 2; - // let (tx, rx) = crate::hive::outcome_channel(); - // // each worker should take `batch_size` tasks for its queue + 1 to work on immediately, - // // meaning there should be `batch_size + 1` tasks associated with each thread ID - // let barrier = 
IndexedBarrier::new(num_threads); - // let task_ids = launch_tasks(hive, num_threads, tasks_per_thread, &barrier, &tx); - // // start the first tasks - // barrier.wait(); - // // wait for all tasks to complete - // hive.join(); - // let thread_counts = count_thread_ids(rx, task_ids); - // assert_eq!(thread_counts.len(), num_threads); - // assert!(thread_counts - // .values() - // .all(|&count| count == tasks_per_thread)); - // } - - // #[test] - // fn test_batching() { - // const NUM_THREADS: usize = 4; - // const BATCH_SIZE: usize = 24; - // let hive = Builder::new() - // .num_threads(NUM_THREADS) - // .batch_size(BATCH_SIZE) - // .build_with_default::, Global<_>, Local<_>>(); - // run_test(&hive, NUM_THREADS, BATCH_SIZE); - // } - - // #[test] - // fn test_set_batch_size() { - // const NUM_THREADS: usize = 4; - // const BATCH_SIZE_0: usize = 10; - // const BATCH_SIZE_1: usize = 20; - // const BATCH_SIZE_2: usize = 50; - // let hive = Builder::new() - // .num_threads(NUM_THREADS) - // .batch_size(BATCH_SIZE_0) - // .build_with_default::, Global<_>, Local<_>>(); - // run_test(&hive, NUM_THREADS, BATCH_SIZE_0); - // // increase batch size - // hive.set_worker_batch_size(BATCH_SIZE_2); - // run_test(&hive, NUM_THREADS, BATCH_SIZE_2); - // // decrease batch size - // hive.set_worker_batch_size(BATCH_SIZE_1); - // run_test(&hive, NUM_THREADS, BATCH_SIZE_1); - // } - - // #[test] - // fn test_shrink_batch_size() { - // const NUM_THREADS: usize = 4; - // const NUM_TASKS_PER_THREAD: usize = 125; - // const BATCH_SIZE_0: usize = 100; - // const BATCH_SIZE_1: usize = 10; - // let hive = Builder::new() - // .num_threads(NUM_THREADS) - // .batch_size(BATCH_SIZE_0) - // .build_with_default::, Global<_>, Local<_>>(); - // let (tx, rx) = crate::hive::outcome_channel(); - // let barrier = IndexedBarrier::new(NUM_THREADS); - // let task_ids = launch_tasks(&hive, NUM_THREADS, NUM_TASKS_PER_THREAD, &barrier, &tx); - // let total_tasks = NUM_THREADS * NUM_TASKS_PER_THREAD; - // assert_eq!(task_ids.len(), total_tasks); - // barrier.wait(); - // hive.set_worker_batch_size(BATCH_SIZE_1); - // // The number of tasks completed by each thread could be variable, so we want to ensure - // // that a) each processed at least `BATCH_SIZE_0` tasks, and b) there are a total of - // // `NUM_TASKS` outputs with no errors - // hive.join(); - // let thread_counts = count_thread_ids(rx, task_ids); - // assert!(thread_counts.values().all(|count| *count > BATCH_SIZE_0)); - // assert_eq!(thread_counts.values().sum::(), total_tasks); - // } - // } - - // #[cfg(all(test, feature = "retry"))] - // mod retry_tests { - // use crate::bee::stock::RetryCaller; - // use crate::bee::{ApplyError, Context}; - // use crate::hive::queue::{ChannelGlobalQueue, DefaultLocalQueues}; - // use crate::hive::{Builder, Hive, Outcome, OutcomeIteratorExt}; - // use std::time::{Duration, SystemTime}; - - // type Global = ChannelGlobalQueue; - // type Local = DefaultLocalQueues>; - - // fn echo_time(i: usize, ctx: &Context) -> Result> { - // let attempt = ctx.attempt(); - // if attempt == 3 { - // Ok("Success".into()) - // } else { - // // the delay between each message should be exponential - // eprintln!("Task {} attempt {}: {:?}", i, attempt, SystemTime::now()); - // Err(ApplyError::Retryable { - // input: i, - // error: "Retryable".into(), - // }) - // } - // } - - // #[test] - // fn test_retries() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .with_thread_per_core() - // .max_retries(3) - // 
.retry_factor(Duration::from_secs(1)) - // .build_with(RetryCaller::of(echo_time)); - - // let v: Result, _> = hive.swarm(0..10).into_results().collect(); - // assert_eq!(v.unwrap().len(), 10); - // } - - // #[test] - // fn test_retries_fail() { - // fn sometimes_fail( - // i: usize, - // _: &Context, - // ) -> Result> { - // match i % 3 { - // 0 => Ok("Success".into()), - // 1 => Err(ApplyError::Retryable { - // input: i, - // error: "Retryable".into(), - // }), - // 2 => Err(ApplyError::Fatal { - // input: Some(i), - // error: "Fatal".into(), - // }), - // _ => unreachable!(), - // } - // } - - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .with_thread_per_core() - // .max_retries(3) - // .build_with(RetryCaller::of(sometimes_fail)); - - // let (success, retry_failed, not_retried) = hive.swarm(0..10).fold( - // (0, 0, 0), - // |(success, retry_failed, not_retried), outcome| match outcome { - // Outcome::Success { .. } => (success + 1, retry_failed, not_retried), - // Outcome::MaxRetriesAttempted { .. } => (success, retry_failed + 1, not_retried), - // Outcome::Failure { .. } => (success, retry_failed, not_retried + 1), - // _ => unreachable!(), - // }, - // ); - - // assert_eq!(success, 4); - // assert_eq!(retry_failed, 3); - // assert_eq!(not_retried, 3); - // } - - // #[test] - // fn test_disable_retries() { - // let hive: Hive<_, _, Global<_>, Local<_>> = Builder::new() - // .with_thread_per_core() - // .with_no_retries() - // .build_with(RetryCaller::of(echo_time)); - // let v: Result, _> = hive.swarm(0..10).into_results().collect(); - // assert!(v.is_err()); - // } + #[test] + fn test_batching() { + const NUM_THREADS: usize = 4; + const BATCH_SIZE: usize = 24; + let hive = ChannelBuilder::empty() + .with_worker_default() + .num_threads(NUM_THREADS) + .batch_size(BATCH_SIZE) + .build(); + run_test(&hive, NUM_THREADS, BATCH_SIZE); + } + + #[test] + fn test_set_batch_size() { + const NUM_THREADS: usize = 4; + const BATCH_SIZE_0: usize = 10; + const BATCH_SIZE_1: usize = 20; + const BATCH_SIZE_2: usize = 50; + let hive = ChannelBuilder::empty() + .with_worker_default() + .num_threads(NUM_THREADS) + .batch_size(BATCH_SIZE_0) + .build(); + run_test(&hive, NUM_THREADS, BATCH_SIZE_0); + // increase batch size + hive.set_worker_batch_size(BATCH_SIZE_2); + run_test(&hive, NUM_THREADS, BATCH_SIZE_2); + // decrease batch size + hive.set_worker_batch_size(BATCH_SIZE_1); + run_test(&hive, NUM_THREADS, BATCH_SIZE_1); + } + + #[test] + fn test_shrink_batch_size() { + const NUM_THREADS: usize = 4; + const NUM_TASKS_PER_THREAD: usize = 125; + const BATCH_SIZE_0: usize = 100; + const BATCH_SIZE_1: usize = 10; + let hive = ChannelBuilder::empty() + .with_worker_default() + .num_threads(NUM_THREADS) + .batch_size(BATCH_SIZE_0) + .build(); + let (tx, rx) = crate::hive::outcome_channel(); + let barrier = IndexedBarrier::new(NUM_THREADS); + let task_ids = launch_tasks(&hive, NUM_THREADS, NUM_TASKS_PER_THREAD, &barrier, &tx); + let total_tasks = NUM_THREADS * NUM_TASKS_PER_THREAD; + assert_eq!(task_ids.len(), total_tasks); + barrier.wait(); + hive.set_worker_batch_size(BATCH_SIZE_1); + // The number of tasks completed by each thread could be variable, so we want to ensure + // that a) each processed at least `BATCH_SIZE_0` tasks, and b) there are a total of + // `NUM_TASKS` outputs with no errors + hive.join(); + let thread_counts = count_thread_ids(rx, task_ids); + assert!(thread_counts.values().all(|count| *count > BATCH_SIZE_0)); + assert_eq!(thread_counts.values().sum::(), 
total_tasks); + } +} + +#[cfg(all(test, feature = "retry"))] +mod retry_tests { + use crate::bee::stock::RetryCaller; + use crate::bee::{ApplyError, Context}; + use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeIteratorExt}; + use std::time::{Duration, SystemTime}; + + fn echo_time(i: usize, ctx: &Context) -> Result> { + let attempt = ctx.attempt(); + if attempt == 3 { + Ok("Success".into()) + } else { + // the delay between each message should be exponential + eprintln!("Task {} attempt {}: {:?}", i, attempt, SystemTime::now()); + Err(ApplyError::Retryable { + input: i, + error: "Retryable".into(), + }) + } + } + + #[test] + fn test_retries() { + let hive = ChannelBuilder::empty() + .with_worker(RetryCaller::of(echo_time)) + .with_thread_per_core() + .max_retries(3) + .retry_factor(Duration::from_secs(1)) + .build(); + + let v: Result, _> = hive.swarm(0..10).into_results().collect(); + assert_eq!(v.unwrap().len(), 10); + } + + #[test] + fn test_retries_fail() { + fn sometimes_fail( + i: usize, + _: &Context, + ) -> Result> { + match i % 3 { + 0 => Ok("Success".into()), + 1 => Err(ApplyError::Retryable { + input: i, + error: "Retryable".into(), + }), + 2 => Err(ApplyError::Fatal { + input: Some(i), + error: "Fatal".into(), + }), + _ => unreachable!(), + } + } + + let hive = ChannelBuilder::empty() + .with_worker(RetryCaller::of(sometimes_fail)) + .with_thread_per_core() + .max_retries(3) + .build(); + + let (success, retry_failed, not_retried) = hive.swarm(0..10).fold( + (0, 0, 0), + |(success, retry_failed, not_retried), outcome| match outcome { + Outcome::Success { .. } => (success + 1, retry_failed, not_retried), + Outcome::MaxRetriesAttempted { .. } => (success, retry_failed + 1, not_retried), + Outcome::Failure { .. } => (success, retry_failed, not_retried + 1), + _ => unreachable!(), + }, + ); + + assert_eq!(success, 4); + assert_eq!(retry_failed, 3); + assert_eq!(not_retried, 3); + } + + #[test] + fn test_disable_retries() { + let hive = ChannelBuilder::empty() + .with_worker(RetryCaller::of(echo_time)) + .with_thread_per_core() + .with_no_retries() + .build(); + let v: Result, _> = hive.swarm(0..10).into_results().collect(); + assert!(v.is_err()); + } } diff --git a/src/hive/outcome/mod.rs b/src/hive/outcome/mod.rs index e908aef..f3d2610 100644 --- a/src/hive/outcome/mod.rs +++ b/src/hive/outcome/mod.rs @@ -5,10 +5,82 @@ mod outcome; mod queue; mod store; -pub use batch::OutcomeBatch; -pub use iter::OutcomeIteratorExt; -pub use outcome::Outcome; -pub use queue::OutcomeQueue; -pub use store::OutcomeStore; +pub use self::batch::OutcomeBatch; +pub use self::iter::OutcomeIteratorExt; +pub use self::queue::OutcomeQueue; +pub use self::store::OutcomeStore; -pub(super) use store::sealed::{DerefOutcomes, OwnedOutcomes}; +pub(super) use self::store::{DerefOutcomes, OwnedOutcomes}; + +use crate::bee::{TaskId, Worker}; +use crate::panic::Panic; + +/// The possible outcomes of a task execution. +/// +/// Each outcome includes the task ID of the task that produced it. Tasks that submitted +/// subtasks (via [`crate::bee::Context::submit_task`]) produce `Outcome` variants that have +/// `subtask_ids`. +/// +/// Note that `Outcome`s can only be compared or ordered with other `Outcome`s produced by the same +/// `Hive`, because comparison/ordering is completely based on the task ID. +#[derive(Debug)] +pub enum Outcome { + /// The task was executed successfully. 
+ Success { value: W::Output, task_id: TaskId }, + /// The task was executed successfully, and it also submitted one or more subtask_ids to the + /// `Hive`. + SuccessWithSubtasks { + value: W::Output, + task_id: TaskId, + subtask_ids: Vec, + }, + /// The task failed with an error that was not retryable. The input value that caused the + /// failure is provided if possible. + Failure { + input: Option, + error: W::Error, + task_id: TaskId, + }, + /// The task failed with an error that was not retryable, but it submitted one or more subtask_ids + /// before failing. The input value that caused the failure is provided if possible. + FailureWithSubtasks { + input: Option, + error: W::Error, + task_id: TaskId, + subtask_ids: Vec, + }, + /// The task was not executed before the Hive was dropped, or processing of the task was + /// interrupted (e.g., by `suspend`ing the `Hive`). + Unprocessed { input: W::Input, task_id: TaskId }, + /// The task was not executed before the Hive was dropped, or processing of the task was + /// interrupted (e.g., by `suspend`ing the `Hive`), but it first submitted one or more subtask_ids. + UnprocessedWithSubtasks { + input: W::Input, + task_id: TaskId, + subtask_ids: Vec, + }, + /// The task with the given task_id was not found in the `Hive` or iterator from which it was + /// being requested. + Missing { task_id: TaskId }, + /// The task panicked. The input value that caused the panic is provided if possible. + Panic { + input: Option, + payload: Panic, + task_id: TaskId, + }, + /// The task panicked, but it submitted one or more subtask_ids before panicking. The input value + /// that caused the panic is provided if possible. + PanicWithSubtasks { + input: Option, + payload: Panic, + task_id: TaskId, + subtask_ids: Vec, + }, + /// The task failed after retrying the maximum number of times. + #[cfg(feature = "retry")] + MaxRetriesAttempted { + input: W::Input, + error: W::Error, + task_id: TaskId, + }, +} diff --git a/src/hive/outcome/outcome.rs b/src/hive/outcome/outcome.rs deleted file mode 100644 index ff926a0..0000000 --- a/src/hive/outcome/outcome.rs +++ /dev/null @@ -1,452 +0,0 @@ -use crate::bee::{ApplyError, TaskId, Worker, WorkerResult}; -use crate::panic::Panic; -use std::cmp::Ordering; -use std::fmt::Debug; - -/// The possible outcomes of a task execution. -/// -/// Each outcome includes the task ID of the task that produced it. Tasks that submitted -/// subtasks (via [`crate::bee::Context::submit_task`]) produce `Outcome` variants that have -/// `subtask_ids`. -/// -/// Note that `Outcome`s can only be compared or ordered with other `Outcome`s produced by the same -/// `Hive`, because comparison/ordering is completely based on the task ID. -#[derive(Debug)] -pub enum Outcome { - /// The task was executed successfully. - Success { value: W::Output, task_id: TaskId }, - /// The task was executed successfully, and it also submitted one or more subtask_ids to the - /// `Hive`. - SuccessWithSubtasks { - value: W::Output, - task_id: TaskId, - subtask_ids: Vec, - }, - /// The task failed with an error that was not retryable. The input value that caused the - /// failure is provided if possible. - Failure { - input: Option, - error: W::Error, - task_id: TaskId, - }, - /// The task failed with an error that was not retryable, but it submitted one or more subtask_ids - /// before failing. The input value that caused the failure is provided if possible. 
- FailureWithSubtasks { - input: Option, - error: W::Error, - task_id: TaskId, - subtask_ids: Vec, - }, - /// The task was not executed before the Hive was dropped, or processing of the task was - /// interrupted (e.g., by `suspend`ing the `Hive`). - Unprocessed { input: W::Input, task_id: TaskId }, - /// The task was not executed before the Hive was dropped, or processing of the task was - /// interrupted (e.g., by `suspend`ing the `Hive`), but it first submitted one or more subtask_ids. - UnprocessedWithSubtasks { - input: W::Input, - task_id: TaskId, - subtask_ids: Vec, - }, - /// The task with the given task_id was not found in the `Hive` or iterator from which it was - /// being requested. - Missing { task_id: TaskId }, - /// The task panicked. The input value that caused the panic is provided if possible. - Panic { - input: Option, - payload: Panic, - task_id: TaskId, - }, - /// The task panicked, but it submitted one or more subtask_ids before panicking. The input value - /// that caused the panic is provided if possible. - PanicWithSubtasks { - input: Option, - payload: Panic, - task_id: TaskId, - subtask_ids: Vec, - }, - /// The task failed after retrying the maximum number of times. - #[cfg(feature = "retry")] - MaxRetriesAttempted { - input: W::Input, - error: W::Error, - task_id: TaskId, - }, -} - -impl Outcome { - /// Converts a worker `result` into an `Outcome` with the given task_id and optional subtask ids. - pub(in crate::hive) fn from_worker_result( - result: WorkerResult, - task_id: TaskId, - subtask_ids: Option>, - ) -> Self { - match (result, subtask_ids) { - (Ok(value), Some(subtask_ids)) => Self::SuccessWithSubtasks { - value, - task_id, - subtask_ids, - }, - (Ok(value), None) => Self::Success { value, task_id }, - (Err(ApplyError::Retryable { input, error, .. }), Some(subtask_ids)) => { - Self::FailureWithSubtasks { - input: Some(input), - error, - task_id, - subtask_ids, - } - } - (Err(ApplyError::Retryable { input, error }), None) => { - #[cfg(feature = "retry")] - { - Self::MaxRetriesAttempted { - input, - error, - task_id, - } - } - #[cfg(not(feature = "retry"))] - { - Self::Failure { - input: Some(input), - error, - task_id, - } - } - } - (Err(ApplyError::Fatal { input, error }), Some(subtask_ids)) => { - Self::FailureWithSubtasks { - input, - error, - task_id, - subtask_ids, - } - } - (Err(ApplyError::Fatal { input, error }), None) => Self::Failure { - input, - error, - task_id, - }, - (Err(ApplyError::Cancelled { input }), Some(subtask_ids)) => { - Self::UnprocessedWithSubtasks { - input, - task_id, - subtask_ids, - } - } - (Err(ApplyError::Cancelled { input }), None) => Self::Unprocessed { input, task_id }, - (Err(ApplyError::Panic { input, payload }), Some(subtask_ids)) => { - Self::PanicWithSubtasks { - input, - payload, - task_id, - subtask_ids, - } - } - (Err(ApplyError::Panic { input, payload }), None) => Self::Panic { - input, - payload, - task_id, - }, - } - } - - /// Returns `true` if this is a `Success` outcome. - pub fn is_success(&self) -> bool { - matches!(self, Self::Success { .. }) - } - - /// Returns `true` if this outcome represents an unprocessed task input. - pub fn is_unprocessed(&self) -> bool { - matches!(self, Self::Unprocessed { .. }) - } - - /// Returns `true` if this outcome represents a task processing failure. - pub fn is_failure(&self) -> bool { - match self { - Self::Failure { .. } | Self::Panic { .. } => true, - #[cfg(feature = "retry")] - Self::MaxRetriesAttempted { .. 
} => true, - _ => false, - } - } - - /// Returns the task_id of the task that produced this outcome. - pub fn task_id(&self) -> &TaskId { - match self { - Self::Success { task_id, .. } - | Self::SuccessWithSubtasks { task_id, .. } - | Self::Failure { task_id, .. } - | Self::FailureWithSubtasks { task_id, .. } - | Self::Unprocessed { task_id, .. } - | Self::UnprocessedWithSubtasks { task_id, .. } - | Self::Missing { task_id } - | Self::Panic { task_id, .. } - | Self::PanicWithSubtasks { task_id, .. } => task_id, - #[cfg(feature = "retry")] - Self::MaxRetriesAttempted { task_id, .. } => task_id, - } - } - - /// Returns the IDs of the tasks submitted by the task that produced this outcome, or `None` - /// if the task did not submit any subtasks. - pub fn subtask_ids(&self) -> Option<&Vec> { - match self { - Self::SuccessWithSubtasks { subtask_ids, .. } - | Self::FailureWithSubtasks { subtask_ids, .. } - | Self::UnprocessedWithSubtasks { subtask_ids, .. } - | Self::PanicWithSubtasks { subtask_ids, .. } => Some(subtask_ids), - _ => None, - } - } - - /// Consumes this `Outcome` and returns the value if it is a `Success`, otherwise panics. - pub fn unwrap(self) -> W::Output { - self.success().expect("not a Success outcome") - } - - /// Consumes this `Outcome` and returns the output value if it is a `Success`, otherwise `None`. - pub fn success(self) -> Option { - match self { - Self::Success { value, .. } => Some(value), - _ => None, - } - } - - /// Consumes this `Outcome` and returns the input value if available, otherwise `None`. - pub fn try_into_input(self) -> Option { - match self { - Self::Failure { input, .. } - | Self::FailureWithSubtasks { input, .. } - | Self::Panic { input, .. } - | Self::PanicWithSubtasks { input, .. } => input, - Self::Unprocessed { input, .. } | Self::UnprocessedWithSubtasks { input, .. } => { - Some(input) - } - Self::Success { .. } | Self::SuccessWithSubtasks { .. } | Self::Missing { .. } => None, - #[cfg(feature = "retry")] - Self::MaxRetriesAttempted { input, .. } => Some(input), - } - } - - /// Consumes this `Outcome` and depending on the variant: - /// * Returns the wrapped error if this is a `Failure` or `MaxRetriesAttempted`, - /// * Resumes unwinding if this is a `Panic` outcome, - /// * Otherwise returns `None`. - pub fn try_into_error(self) -> Option { - match self { - Self::Failure { error, .. } | Self::FailureWithSubtasks { error, .. } => Some(error), - Self::Panic { payload, .. } | Self::PanicWithSubtasks { payload, .. } => { - payload.resume() - } - Self::Success { .. } - | Self::SuccessWithSubtasks { .. } - | Self::Unprocessed { .. } - | Self::UnprocessedWithSubtasks { .. } - | Self::Missing { .. } => None, - #[cfg(feature = "retry")] - Self::MaxRetriesAttempted { error, .. } => Some(error), - } - } -} - -impl PartialEq for Outcome { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Success { task_id: a, .. }, Self::Success { task_id: b, .. }) - | ( - Self::SuccessWithSubtasks { task_id: a, .. }, - Self::SuccessWithSubtasks { task_id: b, .. }, - ) - | (Self::Failure { task_id: a, .. }, Self::Failure { task_id: b, .. }) - | ( - Self::FailureWithSubtasks { task_id: a, .. }, - Self::FailureWithSubtasks { task_id: b, .. }, - ) - | (Self::Unprocessed { task_id: a, .. }, Self::Unprocessed { task_id: b, .. }) - | ( - Self::UnprocessedWithSubtasks { task_id: a, .. }, - Self::UnprocessedWithSubtasks { task_id: b, .. }, - ) - | (Self::Missing { task_id: a }, Self::Missing { task_id: b }) - | (Self::Panic { task_id: a, .. 
}, Self::Panic { task_id: b, .. }) - | ( - Self::PanicWithSubtasks { task_id: a, .. }, - Self::PanicWithSubtasks { task_id: b, .. }, - ) => a == b, - #[cfg(feature = "retry")] - ( - Self::MaxRetriesAttempted { task_id: a, .. }, - Self::MaxRetriesAttempted { task_id: b, .. }, - ) => a == b, - _ => false, - } - } -} - -impl Eq for Outcome {} - -impl PartialOrd for Outcome { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Outcome { - fn cmp(&self, other: &Self) -> Ordering { - self.task_id().cmp(other.task_id()) - } -} - -#[cfg(test)] -mod tests { - use super::Outcome; - use crate::bee::stock::EchoWorker; - use crate::panic::Panic; - - type Worker = EchoWorker; - type WorkerOutcome = Outcome; - - #[test] - fn test_try_into_input() { - let outcome = WorkerOutcome::Success { - value: 42, - task_id: 1, - }; - assert_eq!(outcome.try_into_input(), None); - - let outcome = WorkerOutcome::Failure { - input: None, - error: (), - task_id: 2, - }; - assert_eq!(outcome.try_into_input(), None); - - let outcome = WorkerOutcome::Failure { - input: Some(42), - error: (), - task_id: 2, - }; - assert_eq!(outcome.try_into_input(), Some(42)); - - let outcome = WorkerOutcome::Unprocessed { - input: 42, - task_id: 3, - }; - assert_eq!(outcome.try_into_input(), Some(42)); - - let outcome = WorkerOutcome::Missing { task_id: 4 }; - assert_eq!(outcome.try_into_input(), None); - - let outcome = WorkerOutcome::Panic { - input: None, - payload: Panic::try_call(None, || panic!()).unwrap_err(), - task_id: 5, - }; - assert_eq!(outcome.try_into_input(), None); - - let outcome = WorkerOutcome::Panic { - input: Some(42), - payload: Panic::try_call(None, || panic!()).unwrap_err(), - task_id: 5, - }; - assert_eq!(outcome.try_into_input(), Some(42)); - } - - #[test] - fn test_try_into_error() { - let outcome = WorkerOutcome::Success { - value: 42, - task_id: 1, - }; - assert_eq!(outcome.try_into_error(), None); - - let outcome = WorkerOutcome::Failure { - input: None, - error: (), - task_id: 2, - }; - assert_eq!(outcome.try_into_error(), Some(())); - - let outcome = WorkerOutcome::Failure { - input: Some(42), - error: (), - task_id: 2, - }; - assert_eq!(outcome.try_into_error(), Some(())); - - let outcome = WorkerOutcome::Unprocessed { - input: 42, - task_id: 3, - }; - assert_eq!(outcome.try_into_error(), None); - - let outcome = WorkerOutcome::Missing { task_id: 4 }; - assert_eq!(outcome.try_into_error(), None); - } - - #[test] - #[should_panic] - fn test_try_into_error_panic() { - WorkerOutcome::Panic { - input: None, - payload: Panic::try_call(None, || panic!()).unwrap_err(), - task_id: 5, - } - .try_into_error(); - } - - #[test] - fn test_eq() { - let outcome1 = WorkerOutcome::Success { - value: 42, - task_id: 1, - }; - let outcome2 = WorkerOutcome::Success { - value: 42, - task_id: 1, - }; - assert_eq!(outcome1, outcome2); - - let outcome3 = WorkerOutcome::Success { - value: 42, - task_id: 2, - }; - assert_ne!(outcome1, outcome3); - - let outcome4 = WorkerOutcome::Failure { - input: None, - error: (), - task_id: 1, - }; - assert_ne!(outcome1, outcome4); - } -} - -#[cfg(all(test, feature = "retry"))] -mod retry_tests { - use super::Outcome; - use crate::bee::stock::EchoWorker; - - type Worker = EchoWorker; - type WorkerOutcome = Outcome; - - #[test] - fn test_try_into_input() { - let outcome = WorkerOutcome::MaxRetriesAttempted { - input: 42, - error: (), - task_id: 1, - }; - assert_eq!(outcome.try_into_input(), Some(42)); - } - - #[test] - fn test_try_into_error() { - let 
outcome = WorkerOutcome::MaxRetriesAttempted { - input: 42, - error: (), - task_id: 1, - }; - assert_eq!(outcome.try_into_error(), Some(())); - } -} diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs index b109983..bec9154 100644 --- a/src/hive/outcome/queue.rs +++ b/src/hive/outcome/queue.rs @@ -1,6 +1,5 @@ use super::Outcome; -use crate::bee::Worker; -use crate::hive::TaskId; +use crate::bee::{TaskId, Worker}; use crossbeam_queue::SegQueue; use parking_lot::Mutex; use std::collections::HashMap; diff --git a/src/hive/outcome/store.rs b/src/hive/outcome/store.rs index eb19608..d401ed5 100644 --- a/src/hive/outcome/store.rs +++ b/src/hive/outcome/store.rs @@ -1,31 +1,24 @@ use super::Outcome; use crate::bee::{TaskId, Worker}; -use sealed::DerefOutcomes; +use std::{ + collections::HashMap, + ops::{Deref, DerefMut}, +}; -/// Traits with methods that should only be accessed internally by public traits. -pub mod sealed { - use crate::bee::{TaskId, Worker}; - use crate::hive::Outcome; - use std::{ - collections::HashMap, - ops::{Deref, DerefMut}, - }; +pub trait DerefOutcomes { + /// Returns a read-only reference to a map of task task_id to `Outcome`. + fn outcomes_deref(&self) -> impl Deref>>; - pub trait DerefOutcomes { - /// Returns a read-only reference to a map of task task_id to `Outcome`. - fn outcomes_deref(&self) -> impl Deref>>; - - /// Returns a mutable reference to a map of task task_id to `Outcome`. - fn outcomes_deref_mut(&mut self) -> impl DerefMut>> + '_; - } + /// Returns a mutable reference to a map of task task_id to `Outcome`. + fn outcomes_deref_mut(&mut self) -> impl DerefMut>> + '_; +} - pub trait OwnedOutcomes: Sized { - /// Returns an owned map of task task_id to `Outcome`. - fn outcomes(self) -> HashMap>; +pub trait OwnedOutcomes: Sized { + /// Returns an owned map of task task_id to `Outcome`. + fn outcomes(self) -> HashMap>; - /// Returns a read-only reference to a map of task task_id to `Outcome`. - fn outcomes_ref(&self) -> &HashMap>; - } + /// Returns a read-only reference to a map of task task_id to `Outcome`. + fn outcomes_ref(&self) -> &HashMap>; } /// Trait implemented by structs that store `Outcome`s (`Hive`, `Husk`, and `OutcomeBatch`). @@ -33,7 +26,7 @@ pub mod sealed { /// The first group of methods provided by this trait only require dereferencing the underlying map, /// while the second group of methods require the ability to borrow or take ownership of the /// underlying map (and thus, are not in scope for `Hive`). -pub trait OutcomeStore: sealed::DerefOutcomes { +pub trait OutcomeStore: DerefOutcomes { fn len(&self) -> usize { self.outcomes_deref().len() } @@ -218,7 +211,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// Returns the stored `Outcome` associated with the given task_id, if any. fn get(&self, task_id: TaskId) -> Option<&Outcome> where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { self.outcomes_ref().get(&task_id) } @@ -226,7 +219,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// Consumes this store and returns an iterator over the outcomes in task_id order. fn into_iter(self) -> impl Iterator> where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { self.outcomes().into_values() } @@ -234,7 +227,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// Returns the successes as a `Vec` if there are no errors, otherwise panics. 
fn unwrap(self) -> Vec where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { assert!( !(self.has_failures() || self.has_unprocessed()), @@ -253,7 +246,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// they cause this method to panic. fn ok_or_unwrap_errors(self, drop_unprocessed: bool) -> Result, Vec> where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { assert!( drop_unprocessed || !self.has_unprocessed(), @@ -275,7 +268,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// inputs are returned in task_id order, otherwise they are unordered. fn into_unprocessed(self, ordered: bool) -> Vec where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { let values = self .outcomes() @@ -301,7 +294,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// that were queued but not yet processed when the `Hive` was dropped. fn iter_unprocessed(&self) -> impl Iterator where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { self.outcomes_ref() .values() @@ -315,7 +308,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// that were successfully processed but not sent to any output channel. fn iter_successes(&self) -> impl Iterator where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { self.outcomes_ref() .values() @@ -329,7 +322,7 @@ pub trait OutcomeStore: sealed::DerefOutcomes { /// that were successfully processed but not sent to any output channel. fn iter_failures(&self) -> impl Iterator> where - Self: sealed::OwnedOutcomes, + Self: OwnedOutcomes, { self.outcomes_ref() .values() diff --git a/src/hive/queue/global.rs b/src/hive/queue/global.rs deleted file mode 100644 index 87c737d..0000000 --- a/src/hive/queue/global.rs +++ /dev/null @@ -1,77 +0,0 @@ -use crate::atomic::{Atomic, AtomicBool}; -use crate::bee::Worker; -use crate::hive::{GlobalPopError, GlobalQueue, Task}; -use crossbeam_channel::RecvTimeoutError; -use std::time::Duration; - -/// Type alias for the input task channel sender -type TaskSender = crossbeam_channel::Sender>; -/// Type alias for the input task channel receiver -type TaskReceiver = crossbeam_channel::Receiver>; - -pub struct ChannelGlobalQueue { - tx: TaskSender, - rx: TaskReceiver, - closed: AtomicBool, -} - -impl ChannelGlobalQueue { - /// Returns a new `GlobalQueue` that uses the given channel sender for pushing new tasks - /// and the given channel receiver for popping tasks. 
- pub(super) fn new(tx: TaskSender, rx: TaskReceiver) -> Self { - Self { - tx, - rx, - closed: AtomicBool::default(), - } - } - - pub(super) fn try_pop_timeout( - &self, - timeout: Duration, - ) -> Option, GlobalPopError>> { - match self.rx.recv_timeout(timeout) { - Ok(task) => Some(Ok(task)), - Err(RecvTimeoutError::Disconnected) => Some(Err(GlobalPopError::Closed)), - Err(RecvTimeoutError::Timeout) if self.closed.get() && self.rx.is_empty() => { - Some(Err(GlobalPopError::Closed)) - } - Err(RecvTimeoutError::Timeout) => None, - } - } -} - -impl GlobalQueue for ChannelGlobalQueue { - fn try_push(&self, task: Task) -> Result<(), Task> { - if !self.closed.get() { - self.tx.try_send(task).map_err(|err| err.into_inner()) - } else { - Err(task) - } - } - - fn try_pop(&self) -> Option, GlobalPopError>> { - // time to wait in between polling the retry queue and then the task receiver - const RECV_TIMEOUT: Duration = Duration::from_secs(1); - self.try_pop_timeout(RECV_TIMEOUT) - } - - fn try_iter(&self) -> impl Iterator> + '_ { - self.rx.try_iter() - } - - fn drain(&self) -> Vec> { - self.rx.try_iter().collect() - } - - fn close(&self) { - self.closed.set(true); - } -} - -impl Default for ChannelGlobalQueue { - fn default() -> Self { - let (tx, rx) = crossbeam_channel::unbounded(); - Self::new(tx, rx) - } -} diff --git a/src/hive/queue/mod.rs b/src/hive/queue/mod.rs deleted file mode 100644 index 68586b2..0000000 --- a/src/hive/queue/mod.rs +++ /dev/null @@ -1,44 +0,0 @@ -#[cfg(feature = "retry")] -mod delay; -mod global; -#[cfg(any(feature = "batching", feature = "retry"))] -mod local; -#[cfg(not(any(feature = "batching", feature = "retry")))] -mod null; - -pub use global::ChannelGlobalQueue; -#[cfg(any(feature = "batching", feature = "retry"))] -pub use local::LocalQueuesImpl as DefaultLocalQueues; -#[cfg(not(any(feature = "batching", feature = "retry")))] -pub use null::LocalQueuesImpl as DefaultLocalQueues; - -use super::{GlobalQueue, LocalQueues, QueuePair}; -use crate::bee::Worker; -use std::marker::PhantomData; - -pub(crate) type ChannelQueues = - DefaultQueuePair, DefaultLocalQueues>>; - -pub(crate) struct DefaultQueuePair< - W: Worker, - G: GlobalQueue + Default, - L: LocalQueues + Default, -> { - _worker: PhantomData, - _global: PhantomData, - _local: PhantomData, -} - -impl QueuePair for DefaultQueuePair -where - W: Worker, - G: GlobalQueue + Default, - L: LocalQueues + Default, -{ - type Global = G; - type Local = L; - - fn new() -> (Self::Global, Self::Local) { - (Self::Global::default(), Self::Local::default()) - } -} diff --git a/src/hive/queue/null.rs b/src/hive/queue/null.rs deleted file mode 100644 index e3be134..0000000 --- a/src/hive/queue/null.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::bee::{Queen, Worker}; -use crate::hive::{ChannelGlobalQueue, LocalQueues, Shared, Task}; -use std::marker::PhantomData; - -pub struct LocalQueuesImpl(PhantomData W>); - -impl LocalQueues> for LocalQueuesImpl { - fn init_for_threads>( - &self, - _: usize, - _: usize, - _: &Shared, Self>, - ) { - } - - #[inline(always)] - fn push>( - &self, - task: Task, - _: usize, - shared: &Shared, Self>, - ) { - shared.push_global(task); - } - - #[inline(always)] - fn try_pop>( - &self, - _: usize, - _: &Shared, Self>, - ) -> Option> { - None - } - - fn drain(&self) -> Vec> { - Vec::new() - } -} - -impl Default for LocalQueuesImpl { - fn default() -> Self { - Self(PhantomData) - } -} diff --git a/src/hive/scoped/hive.rs b/src/hive/scoped/hive.rs deleted file mode 100644 index 58232aa..0000000 --- 
a/src/hive/scoped/hive.rs +++ /dev/null @@ -1,152 +0,0 @@ -use crate::{ApplyError, Panic}; -use parking_lot::Mutex; -use std::{ - fmt::Debug, - sync::{mpsc, Arc}, - thread, -}; - -pub type WorkerError = ApplyError<::Input, ::Error>; -pub type WorkerResult = Result<::Output, WorkerError>; - -pub trait Worker: Debug + Sized { - type Input: Send; - type Output: Send; - type Error: Send + Debug; - - fn apply(&mut self, _: Self::Input, _: &Context) -> WorkerResult; -} - -pub trait Queen: Send + Sync { - type Kind: Worker; - - fn create(&mut self) -> Self::Kind; -} - -pub struct Hive> { - queen: Mutex, - num_threads: usize, -} - -#[derive(thiserror::Error, Debug)] -pub enum HiveError { - #[error("Task failed")] - Failed(W::Error), - #[error("Task retried the maximum number of times")] - MaxRetriesAttempted(W::Error), - #[error("Task input was not processed")] - Unprocessed(W::Input), - #[error("Task panicked")] - Panic(Panic), -} - -pub type HiveResult = Result>; -pub type TaskResult = HiveResult<::Output, W>; - -#[derive(Debug, PartialEq, Eq)] -pub enum Outcome { - /// The task was executed successfully. - Success { value: W::Output, task_id: TaskId }, - /// The task failed with an error that was not retryable. - Failure { error: W::Error, task_id: TaskId }, - /// The task failed after retrying the maximum number of times. - MaxRetriesAttempted { error: W::Error, task_id: TaskId }, - /// The task was not executed before the Hive was closed. - Unprocessed { value: W::Input, task_id: TaskId }, - /// The task panicked. - Panic { - payload: Panic, - task_id: TaskId, - }, -} - -impl Outcome { - /// Returns the ID of the task that produced this outcome. - pub fn task_id(&self) -> TaskId { - match self { - Outcome::Success { task_id, .. } - | Outcome::Failure { task_id, .. } - | Outcome::MaxRetriesAttempted { task_id, .. } - | Outcome::Unprocessed { task_id, .. } - | Outcome::Panic { task_id, .. } => *task_id, - } - } - - /// Creates a new `Outcome` from a `Panic`. - pub fn from_panic(payload: Panic, task_id: TaskId) -> Outcome { - Outcome::Panic { payload, task_id } - } - - pub(crate) fn from_panic_result( - result: Result, Panic>, - task_id: TaskId, - ) -> Outcome { - match result { - Ok(result) => Outcome::from_worker_result(result, task_id), - Err(panic) => Outcome::from_panic(panic, task_id), - } - } - - pub(crate) fn from_worker_result(result: WorkerResult, task_id: TaskId) -> Outcome { - match result { - Ok(value) => Self::Success { task_id, value }, - Err(ApplyError::Cancelled { input } | ApplyError::Retryable { input, .. }) => { - Self::Unprocessed { - value: input, - task_id, - } - } - Err(ApplyError::Fatal(error)) => Self::Failure { error, task_id }, - } - } -} - -/// Context for a task. 
-#[derive(Debug, Default)] -pub struct Context { - task_id: TaskId, - attempt: u32, -} - -impl Context { - fn new(task_id: TaskId) -> Self { - Self { - task_id, - attempt: 0, - } - } -} - -impl> Hive { - pub fn map(&self, inputs: impl IntoIterator) { - //-> impl Iterator> { - let (task_tx, task_rx) = mpsc::channel(); - let task_rx = Arc::new(Mutex::new(task_rx)); - let (outcome_tx, outcome_rx) = crate::outcome_channel(); - thread::scope(|scope| { - let join_handles = (0..self.num_threads) - .map(|task_id| { - let task_rx = task_rx.clone(); - let outcome_tx = outcome_tx.clone(); - scope.spawn(move || loop { - let mut worker = self.queen.lock().create(); - if let Ok(input) = task_rx.lock().recv() { - let ctx = Context::new(task_id); - let result: Result, Panic> = - Panic::try_call(None, || worker.apply(input, &ctx)); - let outcome: Outcome = Outcome::from_panic_result(result, task_id); - outcome_tx.send(outcome); - } else { - break; - } - }) - }) - .collect::>(); - }); - // let num_tasks = inputs - // .into_iter() - // .map(|input| self.apply_send(job, tx.clone())) - // .count(); - // rx.into_iter().take(num_tasks).map(Outcome::into_result) - } -} diff --git a/src/hive/scoped/mod.rs b/src/hive/scoped/mod.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/hive/workstealing.rs b/src/hive/workstealing.rs deleted file mode 100644 index 657ad05..0000000 --- a/src/hive/workstealing.rs +++ /dev/null @@ -1,17 +0,0 @@ -use super::{HiveInner, LocalQueue, Shared, Task}; -use crate::bee::{Context, Queen, Worker}; -use crossbeam_deque::Injector as GlobalQueue; -use std::sync::Arc; - -type WorkerQueue = crossbeam_deque::Worker>; - -impl LocalQueue for WorkerQueue { - fn new>(shared: &Arc>) -> Self - where - Self: Sized, - { - Self::new_fifo() - } -} - -pub struct WorkstealingHive>(Option>>); diff --git a/src/util.rs b/src/util.rs index 1160293..86916c4 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,9 +4,7 @@ //! creating the [`Hive`](crate::hive::Hive), submitting tasks, collecting results, and shutting //! down the `Hive` properly. 
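The hunk below rewires these convenience helpers from the old `build_with::<_, ChannelQueues<_>>` calls onto the new `ChannelBuilder` API. As a minimal sketch of the resulting pattern (the `double_all` helper and the `beekeeper` crate-root imports are illustrative, not part of the patch):

use beekeeper::bee::stock::Caller;
use beekeeper::hive::{Builder, ChannelBuilder, Outcome};

// Sketch only: mirrors the body of the `map` helper in this file. `Caller::of`
// wraps a plain closure as a worker, and `Outcome::unwrap` panics if the task
// did not complete successfully.
fn double_all(inputs: impl IntoIterator<Item = u32>) -> Vec<u32> {
    ChannelBuilder::default()
        .num_threads(4)
        .with_worker(Caller::of(|i: u32| i * 2))
        .build()
        .map(inputs)
        .map(Outcome::unwrap)
        .collect()
}

The same shape applies to the fallible helpers in this file, which wrap the closure in `OnceCaller` or `RetryCaller` and collect into an `OutcomeBatch` instead of panicking.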
use crate::bee::stock::{Caller, OnceCaller}; -use crate::hive::{ - Builder, ChannelGlobalQueue, ChannelQueues, DefaultLocalQueues, Outcome, OutcomeBatch, -}; +use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeBatch}; use std::fmt::Debug; /// Convenience function that creates a `Hive` with `num_threads` worker threads that execute the @@ -30,9 +28,10 @@ where Inputs: IntoIterator, F: FnMut(I) -> O + Send + Sync + Clone + 'static, { - Builder::default() + ChannelBuilder::default() .num_threads(num_threads) - .build_with::<_, ChannelQueues<_>>(Caller::of(f)) + .with_worker(Caller::of(f)) + .build() .map(inputs) .map(Outcome::unwrap) .collect() @@ -69,9 +68,10 @@ where Inputs: IntoIterator, F: FnMut(I) -> Result + Send + Sync + Clone + 'static, { - Builder::default() + ChannelBuilder::default() .num_threads(num_threads) - .build_with::<_, ChannelQueues<_>>(OnceCaller::of(f)) + .with_worker(OnceCaller::of(f)) + .build() .map(inputs) .into() } @@ -117,7 +117,7 @@ pub use retry::try_map_retryable; mod retry { use crate::bee::stock::RetryCaller; use crate::bee::{ApplyError, Context}; - use crate::hive::{Builder, ChannelQueues, OutcomeBatch}; + use crate::hive::{Builder, ChannelBuilder, OutcomeBatch}; use std::fmt::Debug; /// Convenience function that creates a `Hive` with `num_threads` worker threads that execute the @@ -158,10 +158,11 @@ mod retry { Inputs: IntoIterator, F: FnMut(I, &Context) -> Result> + Send + Sync + Clone + 'static, { - Builder::default() + ChannelBuilder::default() .num_threads(num_threads) .max_retries(max_retries) - .build_with::<_, ChannelQueues<_>>(RetryCaller::of(f)) + .with_worker(RetryCaller::of(f)) + .build() .map(inputs) .into() } From 6e6d97c811900a389c429f77c32356497d8891d9 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sun, 16 Feb 2025 15:21:41 -0800 Subject: [PATCH 10/67] reorganize hive module --- src/hive/outcome/outcome.rs | 381 ++++++++++++++++++++++++++++++++++++ 1 file changed, 381 insertions(+) create mode 100644 src/hive/outcome/outcome.rs diff --git a/src/hive/outcome/outcome.rs b/src/hive/outcome/outcome.rs new file mode 100644 index 0000000..777430d --- /dev/null +++ b/src/hive/outcome/outcome.rs @@ -0,0 +1,381 @@ +use super::Outcome; +use crate::bee::{ApplyError, TaskId, Worker, WorkerResult}; +use std::cmp::Ordering; + +impl Outcome { + /// Converts a worker `result` into an `Outcome` with the given task_id and optional subtask ids. + pub(in crate::hive) fn from_worker_result( + result: WorkerResult, + task_id: TaskId, + subtask_ids: Option>, + ) -> Self { + match (result, subtask_ids) { + (Ok(value), Some(subtask_ids)) => Self::SuccessWithSubtasks { + value, + task_id, + subtask_ids, + }, + (Ok(value), None) => Self::Success { value, task_id }, + (Err(ApplyError::Retryable { input, error, .. 
}), Some(subtask_ids)) => { + Self::FailureWithSubtasks { + input: Some(input), + error, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Retryable { input, error }), None) => { + #[cfg(feature = "retry")] + { + Self::MaxRetriesAttempted { + input, + error, + task_id, + } + } + #[cfg(not(feature = "retry"))] + { + Self::Failure { + input: Some(input), + error, + task_id, + } + } + } + (Err(ApplyError::Fatal { input, error }), Some(subtask_ids)) => { + Self::FailureWithSubtasks { + input, + error, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Fatal { input, error }), None) => Self::Failure { + input, + error, + task_id, + }, + (Err(ApplyError::Cancelled { input }), Some(subtask_ids)) => { + Self::UnprocessedWithSubtasks { + input, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Cancelled { input }), None) => Self::Unprocessed { input, task_id }, + (Err(ApplyError::Panic { input, payload }), Some(subtask_ids)) => { + Self::PanicWithSubtasks { + input, + payload, + task_id, + subtask_ids, + } + } + (Err(ApplyError::Panic { input, payload }), None) => Self::Panic { + input, + payload, + task_id, + }, + } + } + + /// Returns `true` if this is a `Success` outcome. + pub fn is_success(&self) -> bool { + matches!(self, Self::Success { .. }) + } + + /// Returns `true` if this outcome represents an unprocessed task input. + pub fn is_unprocessed(&self) -> bool { + matches!(self, Self::Unprocessed { .. }) + } + + /// Returns `true` if this outcome represents a task processing failure. + pub fn is_failure(&self) -> bool { + match self { + Self::Failure { .. } | Self::Panic { .. } => true, + #[cfg(feature = "retry")] + Self::MaxRetriesAttempted { .. } => true, + _ => false, + } + } + + /// Returns the task_id of the task that produced this outcome. + pub fn task_id(&self) -> &TaskId { + match self { + Self::Success { task_id, .. } + | Self::SuccessWithSubtasks { task_id, .. } + | Self::Failure { task_id, .. } + | Self::FailureWithSubtasks { task_id, .. } + | Self::Unprocessed { task_id, .. } + | Self::UnprocessedWithSubtasks { task_id, .. } + | Self::Missing { task_id } + | Self::Panic { task_id, .. } + | Self::PanicWithSubtasks { task_id, .. } => task_id, + #[cfg(feature = "retry")] + Self::MaxRetriesAttempted { task_id, .. } => task_id, + } + } + + /// Returns the IDs of the tasks submitted by the task that produced this outcome, or `None` + /// if the task did not submit any subtasks. + pub fn subtask_ids(&self) -> Option<&Vec> { + match self { + Self::SuccessWithSubtasks { subtask_ids, .. } + | Self::FailureWithSubtasks { subtask_ids, .. } + | Self::UnprocessedWithSubtasks { subtask_ids, .. } + | Self::PanicWithSubtasks { subtask_ids, .. } => Some(subtask_ids), + _ => None, + } + } + + /// Consumes this `Outcome` and returns the value if it is a `Success`, otherwise panics. + pub fn unwrap(self) -> W::Output { + self.success().expect("not a Success outcome") + } + + /// Consumes this `Outcome` and returns the output value if it is a `Success`, otherwise `None`. + pub fn success(self) -> Option { + match self { + Self::Success { value, .. } => Some(value), + _ => None, + } + } + + /// Consumes this `Outcome` and returns the input value if available, otherwise `None`. + pub fn try_into_input(self) -> Option { + match self { + Self::Failure { input, .. } + | Self::FailureWithSubtasks { input, .. } + | Self::Panic { input, .. } + | Self::PanicWithSubtasks { input, .. } => input, + Self::Unprocessed { input, .. } | Self::UnprocessedWithSubtasks { input, .. 
} => { + Some(input) + } + Self::Success { .. } | Self::SuccessWithSubtasks { .. } | Self::Missing { .. } => None, + #[cfg(feature = "retry")] + Self::MaxRetriesAttempted { input, .. } => Some(input), + } + } + + /// Consumes this `Outcome` and depending on the variant: + /// * Returns the wrapped error if this is a `Failure` or `MaxRetriesAttempted`, + /// * Resumes unwinding if this is a `Panic` outcome, + /// * Otherwise returns `None`. + pub fn try_into_error(self) -> Option { + match self { + Self::Failure { error, .. } | Self::FailureWithSubtasks { error, .. } => Some(error), + Self::Panic { payload, .. } | Self::PanicWithSubtasks { payload, .. } => { + payload.resume() + } + Self::Success { .. } + | Self::SuccessWithSubtasks { .. } + | Self::Unprocessed { .. } + | Self::UnprocessedWithSubtasks { .. } + | Self::Missing { .. } => None, + #[cfg(feature = "retry")] + Self::MaxRetriesAttempted { error, .. } => Some(error), + } + } +} + +impl PartialEq for Outcome { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Success { task_id: a, .. }, Self::Success { task_id: b, .. }) + | ( + Self::SuccessWithSubtasks { task_id: a, .. }, + Self::SuccessWithSubtasks { task_id: b, .. }, + ) + | (Self::Failure { task_id: a, .. }, Self::Failure { task_id: b, .. }) + | ( + Self::FailureWithSubtasks { task_id: a, .. }, + Self::FailureWithSubtasks { task_id: b, .. }, + ) + | (Self::Unprocessed { task_id: a, .. }, Self::Unprocessed { task_id: b, .. }) + | ( + Self::UnprocessedWithSubtasks { task_id: a, .. }, + Self::UnprocessedWithSubtasks { task_id: b, .. }, + ) + | (Self::Missing { task_id: a }, Self::Missing { task_id: b }) + | (Self::Panic { task_id: a, .. }, Self::Panic { task_id: b, .. }) + | ( + Self::PanicWithSubtasks { task_id: a, .. }, + Self::PanicWithSubtasks { task_id: b, .. }, + ) => a == b, + #[cfg(feature = "retry")] + ( + Self::MaxRetriesAttempted { task_id: a, .. }, + Self::MaxRetriesAttempted { task_id: b, .. 
}, + ) => a == b, + _ => false, + } + } +} + +impl Eq for Outcome {} + +impl PartialOrd for Outcome { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Outcome { + fn cmp(&self, other: &Self) -> Ordering { + self.task_id().cmp(other.task_id()) + } +} + +#[cfg(test)] +mod tests { + use super::Outcome; + use crate::bee::stock::EchoWorker; + use crate::panic::Panic; + + type Worker = EchoWorker; + type WorkerOutcome = Outcome; + + #[test] + fn test_try_into_input() { + let outcome = WorkerOutcome::Success { + value: 42, + task_id: 1, + }; + assert_eq!(outcome.try_into_input(), None); + + let outcome = WorkerOutcome::Failure { + input: None, + error: (), + task_id: 2, + }; + assert_eq!(outcome.try_into_input(), None); + + let outcome = WorkerOutcome::Failure { + input: Some(42), + error: (), + task_id: 2, + }; + assert_eq!(outcome.try_into_input(), Some(42)); + + let outcome = WorkerOutcome::Unprocessed { + input: 42, + task_id: 3, + }; + assert_eq!(outcome.try_into_input(), Some(42)); + + let outcome = WorkerOutcome::Missing { task_id: 4 }; + assert_eq!(outcome.try_into_input(), None); + + let outcome = WorkerOutcome::Panic { + input: None, + payload: Panic::try_call(None, || panic!()).unwrap_err(), + task_id: 5, + }; + assert_eq!(outcome.try_into_input(), None); + + let outcome = WorkerOutcome::Panic { + input: Some(42), + payload: Panic::try_call(None, || panic!()).unwrap_err(), + task_id: 5, + }; + assert_eq!(outcome.try_into_input(), Some(42)); + } + + #[test] + fn test_try_into_error() { + let outcome = WorkerOutcome::Success { + value: 42, + task_id: 1, + }; + assert_eq!(outcome.try_into_error(), None); + + let outcome = WorkerOutcome::Failure { + input: None, + error: (), + task_id: 2, + }; + assert_eq!(outcome.try_into_error(), Some(())); + + let outcome = WorkerOutcome::Failure { + input: Some(42), + error: (), + task_id: 2, + }; + assert_eq!(outcome.try_into_error(), Some(())); + + let outcome = WorkerOutcome::Unprocessed { + input: 42, + task_id: 3, + }; + assert_eq!(outcome.try_into_error(), None); + + let outcome = WorkerOutcome::Missing { task_id: 4 }; + assert_eq!(outcome.try_into_error(), None); + } + + #[test] + #[should_panic] + fn test_try_into_error_panic() { + WorkerOutcome::Panic { + input: None, + payload: Panic::try_call(None, || panic!()).unwrap_err(), + task_id: 5, + } + .try_into_error(); + } + + #[test] + fn test_eq() { + let outcome1 = WorkerOutcome::Success { + value: 42, + task_id: 1, + }; + let outcome2 = WorkerOutcome::Success { + value: 42, + task_id: 1, + }; + assert_eq!(outcome1, outcome2); + + let outcome3 = WorkerOutcome::Success { + value: 42, + task_id: 2, + }; + assert_ne!(outcome1, outcome3); + + let outcome4 = WorkerOutcome::Failure { + input: None, + error: (), + task_id: 1, + }; + assert_ne!(outcome1, outcome4); + } +} + +#[cfg(all(test, feature = "retry"))] +mod retry_tests { + use super::Outcome; + use crate::bee::stock::EchoWorker; + + type Worker = EchoWorker; + type WorkerOutcome = Outcome; + + #[test] + fn test_try_into_input() { + let outcome = WorkerOutcome::MaxRetriesAttempted { + input: 42, + error: (), + task_id: 1, + }; + assert_eq!(outcome.try_into_input(), Some(42)); + } + + #[test] + fn test_try_into_error() { + let outcome = WorkerOutcome::MaxRetriesAttempted { + input: 42, + error: (), + task_id: 1, + }; + assert_eq!(outcome.try_into_error(), Some(())); + } +} From 6e57c6aaf650840f7cec4146980f1d242aac53f6 Mon Sep 17 00:00:00 2001 From: jdidion Date: Tue, 18 Feb 2025 16:05:45 -0800 
Subject: [PATCH 11/67] cleanup

---
 CHANGELOG.md                         |   4 +-
 README.md                            |   6 +-
 src/bee/worker.rs                    |   6 +-
 src/hive/hive.rs                     | 179 ++++++----
 src/hive/inner/builder.rs            |  53 +--
 src/hive/inner/config.rs             |  61 ++--
 src/hive/inner/mod.rs                |   9 +-
 src/hive/inner/queue/channel.rs      | 512 ++++++++++++---------
 src/hive/inner/queue/delay.rs        |  60 ++--
 src/hive/inner/queue/mod.rs          |  81 ++---
 src/hive/inner/queue/workstealing.rs | 139 ++++++--
 src/hive/inner/shared.rs             | 213 ++++++-----
 src/hive/mod.rs                      |  78 ++--
 src/hive/outcome/queue.rs            |  14 +-
 14 files changed, 770 insertions(+), 645 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9162526..2b7ce28 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,12 +19,12 @@ The general theme of this release is performance improvement by eliminating thre
 * Added the `TaskQueues` trait, which enables `Hive` to be specialized for different implementations of global (i.e., sending tasks from the `Hive` to worker threads) and local (i.e., worker thread-specific) queues.
 * `ChannelTaskQueues` implements the existing behavior, using a channel for sending tasks.
 * `WorkstealingTaskQueues` has been added to implement the workstealing pattern, based on `crossbeam::deque`.
-* Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. This feature is only used by `ChannelTaskQueues`.
+* Added the `batching` feature, which enables `ChannelTaskQueues` to use worker-thread local queues to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks.
 * Added the `Context::submit` method, which enables tasks to submit new tasks to the `Hive`.
 * Other
 * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread contention.
 * Switched to storing `Outcome`s in the hive using a data structure that does not require locking when inserting, which should reduce thread contention when using `*_store` operations.
-* Switched to using `crossbeam_channel` for the task input channel in `ChannelTaskQueues`.
+* Switched to using `crossbeam_channel` for the task input channel in `ChannelTaskQueues`. These are multi-producer, multi-consumer (MPMC) channels (as opposed to `std::mpsc`, which is single-consumer), which means it is no longer necessary for worker threads to acquire a `Mutex` lock on the channel receiver when getting tasks.

 ## 0.2.1

diff --git a/README.md b/README.md
index cc4e309..0bea8b5 100644
--- a/README.md
+++ b/README.md
@@ -337,9 +337,9 @@ if !exec_err_codes.is_empty() {

 ## Status

-The `beekeeper` API is generally considered to be stable, but additional real-world battle-testing
-is desired before promoting the version to `1.0.0`. If you identify bugs or have suggestions for
-improvement, please [open an issue](https://github.com/jdidion/beekeeper/issues).
+Early versions of this crate (< 0.3) had some fatal design flaws that needed to be corrected with breaking changes (see the [changelog](CHANGELOG.md)).
+
+As of version 0.3, the `beekeeper` API is generally considered to be stable, but additional real-world battle-testing is desired before promoting the version to `1.0.0`. If you identify bugs or have suggestions for improvement, please [open an issue](https://github.com/jdidion/beekeeper/issues).
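To make the queue and batching changes described in the changelog above concrete, here is a minimal sketch assembled from this patch's own tests (the turbofish on `with_worker_default` is an assumption made so the snippet stands alone; the tests rely on type inference, and `batch_size` only has an effect when the `batching` feature is enabled):

use beekeeper::bee::stock::{Thunk, ThunkWorker};
use beekeeper::hive::{Builder, ChannelBuilder, OutcomeIteratorExt};

fn sum_of_squares() -> i32 {
    // Workers pull tasks from the global channel in local batches of 16.
    let hive = ChannelBuilder::empty()
        .with_worker_default::<ThunkWorker<i32>>() // turbofish assumed
        .num_threads(4)
        .batch_size(16) // no-op without the `batching` feature
        .build();
    // `swarm` keeps outputs in input order; `into_outputs` yields the success values.
    hive.swarm((0..10).map(|i: i32| Thunk::of(move || i * i)))
        .into_outputs()
        .sum()
}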
## Similar libraries diff --git a/src/bee/worker.rs b/src/bee/worker.rs index adbc93b..cd85e16 100644 --- a/src/bee/worker.rs +++ b/src/bee/worker.rs @@ -133,11 +133,7 @@ mod tests { type Output = u8; type Error = (); - fn apply_ref( - &mut self, - input: &Self::Input, - _: &Context, - ) -> RefWorkerResult { + fn apply_ref(&mut self, input: &Self::Input, _: &Context) -> RefWorkerResult { match *input { 0 => Err(ApplyRefError::Retryable(())), 1 => Err(ApplyRefError::Fatal(())), diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 5787250..19e86c8 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -1,9 +1,8 @@ use super::{ ChannelBuilder, ChannelTaskQueues, Config, DerefOutcomes, Husk, Outcome, OutcomeBatch, - OutcomeIteratorExt, OutcomeSender, Shared, SpawnError, TaskQueues, + OutcomeIteratorExt, OutcomeSender, Shared, SpawnError, TaskQueues, WorkerQueues, }; use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; -use crossbeam_utils::Backoff; use std::collections::HashMap; use std::fmt; use std::ops::{Deref, DerefMut}; @@ -46,13 +45,15 @@ impl, T: TaskQueues> Hive { Self::init_thread(thread_index, &shared); // create a Sentinel that will spawn a new thread on panic until it is cancelled let sentinel = Sentinel::new(thread_index, Arc::clone(&shared)); + // get the thread-local interface to the task queues + let worker_queues = shared.worker_queues(thread_index); // create a new worker to process tasks let mut worker = shared.create_worker(); // execute the main loop: get the next task to process, which decrements the queued // counter and increments the active counter - while let Some(task) = shared.get_next_task(thread_index) { + while let Some(task) = shared.get_next_task(&worker_queues) { // execute the task and dispose of the outcome - Self::execute(task, thread_index, &mut worker, &shared); + Self::execute(task, &mut worker, &worker_queues, &shared); // finish the task - decrements the active counter and notifies other threads shared.finish_task(false); } @@ -93,6 +94,20 @@ impl, T: TaskQueues> Hive { self.grow(num_threads) } + /// Returns the batch limit for worker threads. + pub fn worker_batch_limit(&self) -> usize { + self.shared().worker_batch_limit() + } + + /// Sets the batch limit for worker threads. + /// + /// Depending on this hive's `TaskQueues` implementation, this method may: + /// * have no effect (if it does not support local batching) + /// * block the current thread until all worker thread queues can be resized. + pub fn set_worker_batch_limit(&self, batch_limit: usize) { + self.shared().set_worker_batch_limit(batch_limit); + } + /// Sends one `input` to the `Hive` for procesing and returns the result, blocking until the /// result is available. Creates a channel to send the input and receive the outcome. Returns /// an [`Outcome`] with the task output or an error. @@ -518,6 +533,42 @@ impl, T: TaskQueues> Hive { self.shared().take_outcomes() } + fn try_close(mut self) -> Option> { + if self.shared().num_referrers() > 1 { + return None; + } + // take the inner value and replace it with `None` + let shared = self.0.take().unwrap(); + // close the global queue to prevent new tasks from being submitted + shared.close_task_queues(); + // wait for all tasks to finish + shared.wait_on_done(); + // unwrap the Arc and return the inner Shared value + Some(super::unwrap_arc(shared)) + } + + /// Consumes this `Hive` and attempts to shut it down gracefully. + /// + /// All unprocessed tasks and stored outcomes are discarded. 
+ /// + /// If this `Hive` has been cloned, and those clones have not been dropped, this method returns + /// `false`. + /// + /// Note that it is not necessary to call this method explicitly - all resources are dropped + /// automatically when the last clone of the hive is dropped. + pub fn close(self) -> bool { + self.try_close().is_some() + } + + /// Consumes this `Hive` and attempts to convert any remaining unprocessed tasks into + /// `Unprocessed` outcomes and either sends each to its outcome channel or adds it to the + /// stored outcomes. + /// + /// Returns a map of stored outcomes. + pub fn try_into_outcomes(self) -> Option>> { + self.try_close().map(|shared| shared.into_outcomes()) + } + /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this /// `Hive`, including any stored task outcomes, and all the data necessary to create a new /// `Hive`. @@ -526,33 +577,8 @@ impl, T: TaskQueues> Hive { /// returns `None` since it cannot take exclusive ownership of the internal shared data. /// /// This method first joins on the `Hive` to wait for all tasks to finish. - pub fn try_into_husk(mut self) -> Option> { - if self.shared().num_referrers() > 1 { - return None; - } - // take the inner value and replace it with `None` - let mut shared = self.0.take().unwrap(); - // close the global queue to prevent new tasks from being submitted - shared.close(); - // wait for all tasks to finish - shared.wait_on_done(); - // wait for worker threads to drop, then take ownership of the shared data and convert it - // into a Husk - let mut backoff = None::; - loop { - // TODO: may want to have some timeout or other kind of limit to prevent this from - // looping forever if a worker thread somehow gets stuck, or if the `num_referrers` - // counter is corrupted - shared = match Arc::try_unwrap(shared) { - Ok(shared) => { - return Some(shared.into_husk()); - } - Err(shared) => { - backoff.get_or_insert_with(Backoff::new).spin(); - shared - } - }; - } + pub fn try_into_husk(self) -> Option> { + self.try_close().map(|shared| shared.into_husk()) } } @@ -767,37 +793,13 @@ mod affinity { } } -#[cfg(feature = "batching")] -mod batching { - use crate::bee::{Queen, Worker}; - use crate::hive::{Hive, TaskQueues}; - - impl Hive - where - W: Worker, - Q: Queen, - T: TaskQueues, - { - /// Returns the batch size for worker threads. - pub fn worker_batch_size(&self) -> usize { - self.shared().batch_size() - } - - /// Sets the batch size for worker threads. This will block the current thread until all - /// worker thread queues can be resized. 
- pub fn set_worker_batch_size(&self, batch_size: usize) { - self.shared().set_batch_size(batch_size); - } - } -} - struct HiveTaskContext<'a, W, Q, T> where W: Worker, Q: Queen, T: TaskQueues, { - thread_index: usize, + worker_queues: &'a T::WorkerQueues, shared: &'a Arc>, outcome_tx: Option<&'a OutcomeSender>, } @@ -813,8 +815,10 @@ where } fn submit_task(&self, input: W::Input) -> TaskId { - self.shared - .send_one_local(input, self.outcome_tx, self.thread_index) + let task = self.shared.prepare_task(input, self.outcome_tx); + let task_id = task.id(); + self.worker_queues.push(task); + task_id } } @@ -844,13 +848,13 @@ mod no_retry { { pub(super) fn execute( task: Task, - thread_index: usize, worker: &mut W, + worker_queues: &T::WorkerQueues, shared: &Arc>, ) { let (task_id, input, outcome_tx) = task.into_parts(); let task_ctx = HiveTaskContext { - thread_index, + worker_queues, shared, outcome_tx: outcome_tx.as_ref(), }; @@ -869,6 +873,7 @@ mod retry { use crate::bee::{ApplyError, Context, Queen, Worker}; use crate::hive::{Hive, Outcome, Shared, Task, TaskQueues}; use std::sync::Arc; + use std::time::Duration; impl Hive where @@ -876,15 +881,35 @@ mod retry { Q: Queen, T: TaskQueues, { + /// Returns the current retry limit for this hive. + pub fn worker_retry_limit(&self) -> u32 { + self.shared().worker_retry_limit() + } + + /// Updates the retry limit for this hive and returns the previous value. + pub fn set_worker_retry_limit(&self, limit: u32) -> u32 { + self.shared().set_worker_retry_limit(limit) + } + + /// Returns the current retry factor for this hive. + pub fn worker_retry_factor(&self) -> Duration { + self.shared().worker_retry_factor() + } + + /// Updates the retry factor for this hive and returns the previous value. + pub fn set_worker_retry_factor(&self, duration: Duration) -> Duration { + self.shared().set_worker_retry_factor(duration) + } + pub(super) fn execute( task: Task, - thread_index: usize, worker: &mut W, + worker_queues: &T::WorkerQueues, shared: &Arc>, ) { let (task_id, input, attempt, outcome_tx) = task.into_parts(); let task_ctx = HiveTaskContext { - thread_index, + worker_queues, shared, outcome_tx: outcome_tx.as_ref(), }; @@ -893,17 +918,31 @@ mod retry { // be the only place where a panic can occur let result = worker.apply(input, &ctx); let subtask_ids = ctx.into_subtask_ids(); - match result { - Err(ApplyError::Retryable { input, .. }) + #[cfg(feature = "retry")] + let result = match result { + Err(ApplyError::Retryable { input, error }) if subtask_ids.is_none() && shared.can_retry(attempt) => { - shared.send_retry(task_id, input, outcome_tx, attempt + 1, thread_index); + match shared.try_send_retry( + task_id, + input, + outcome_tx.as_ref(), + attempt + 1, + worker_queues, + ) { + Ok(_) => return, + Err(task) => Result::>::Err( + ApplyError::Fatal { + input: Some(task.into_parts().1), + error, + }, + ), + } } - result => { - let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); - shared.send_or_store_outcome(outcome, outcome_tx); - } - } + result => result, + }; + let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); + shared.send_or_store_outcome(outcome, outcome_tx); } } } diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 024837e..ee2a850 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -137,6 +137,31 @@ pub trait Builder: BuilderConfig + Sized { self } + /// Sets the worker thread batch size. 
+ /// + /// This may have no effect if the `batching` feature is disabled, or if the `TaskQueues` + /// implementation used for this hive does not support local batching. + /// + /// If `batch_limit` is `0`, batching is effectively disabled, but note that the performance + /// may be worse than with the `batching` feature disabled. + fn batch_limit(mut self, batch_limit: usize) -> Self { + if batch_limit == 0 { + self.config(Token).batch_limit.set(None); + } else { + self.config(Token).batch_limit.set(Some(batch_limit)); + } + self + } + + /// Sets the worker thread batch size to the global default value. + fn with_default_batch_limit(mut self) -> Self { + let _ = self + .config(Token) + .batch_limit + .set(super::config::DEFAULTS.lock().batch_limit.get()); + self + } + /// Sets set list of CPU core indices to which threads in the `Hive` should be pinned. /// /// Core indices are integers in the range `0..N`, where `N` is the number of available CPU @@ -188,28 +213,6 @@ pub trait Builder: BuilderConfig + Sized { self } - /// Sets the worker thread batch size. If `batch_size` is `0`, batching is disabled, but - /// note that the performance may be worse than with the `batching` feature disabled. - #[cfg(feature = "batching")] - fn batch_size(mut self, batch_size: usize) -> Self { - if batch_size == 0 { - self.config(Token).batch_size.set(None); - } else { - self.config(Token).batch_size.set(Some(batch_size)); - } - self - } - - /// Sets the worker thread batch size to the global default value. - #[cfg(feature = "batching")] - fn with_default_batch_size(mut self) -> Self { - let _ = self - .config(Token) - .batch_size - .set(super::config::DEFAULTS.lock().batch_size.get()); - self - } - /// Sets the maximum number of times to retry a /// [`ApplyError::Retryable`](crate::bee::ApplyError::Retryable) error. A worker /// thread will retry a task until it either returns @@ -298,10 +301,10 @@ pub trait Builder: BuilderConfig + Sized { /// ``` #[cfg(feature = "retry")] fn retry_factor(mut self, duration: std::time::Duration) -> Self { - let _ = if duration == std::time::Duration::ZERO { - self.config(Token).retry_factor.set(None) + if duration == std::time::Duration::ZERO { + let _ = self.config(Token).retry_factor.set(None); } else { - self.config(Token).set_retry_factor_from(duration) + let _ = self.config(Token).set_retry_factor_from(duration); }; self } diff --git a/src/hive/inner/config.rs b/src/hive/inner/config.rs index 6db735b..b1d2dd4 100644 --- a/src/hive/inner/config.rs +++ b/src/hive/inner/config.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "batching")] -pub use self::batching::set_batch_size_default; #[cfg(feature = "retry")] pub use self::retry::{ set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, @@ -10,6 +8,7 @@ use parking_lot::Mutex; use std::sync::LazyLock; const DEFAULT_NUM_THREADS: usize = 4; +const DEFAULT_BATCH_LIMIT: usize = 10; pub static DEFAULTS: LazyLock> = LazyLock::new(|| { let mut config = Config::empty(); @@ -28,6 +27,10 @@ pub fn set_num_threads_default_all() { set_num_threads_default(num_cpus::get()); } +pub fn set_batch_limit_default(batch_limit: usize) { + DEFAULTS.lock().batch_limit.set(Some(batch_limit)); +} + /// Resets all builder defaults to their original values. 
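For context, the `batch_limit` builder methods above slot into the usual builder chain; a usage sketch mirroring the crate's own batching tests (the worker type behind `with_worker_default` is inferred from how the hive is used, and the values here are arbitrary):

```rust
let hive = ChannelBuilder::empty()
    .with_worker_default()
    .num_threads(4)
    // cap each worker thread's local queue at 24 tasks;
    // batch_limit(0) effectively disables local batching
    .batch_limit(24)
    .build();
```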
pub fn reset_defaults() { let mut config = DEFAULTS.lock(); @@ -41,10 +44,9 @@ impl Config { num_threads: Default::default(), thread_name: Default::default(), thread_stack_size: Default::default(), + batch_limit: Default::default(), #[cfg(feature = "affinity")] affinity: Default::default(), - #[cfg(feature = "batching")] - batch_size: Default::default(), #[cfg(feature = "retry")] max_retries: Default::default(), #[cfg(feature = "retry")] @@ -55,8 +57,7 @@ impl Config { /// Resets config values to their pre-configured defaults. fn set_const_defaults(&mut self) { self.num_threads.set(Some(DEFAULT_NUM_THREADS)); - #[cfg(feature = "batching")] - self.set_batch_const_defaults(); + self.batch_limit.set(Some(DEFAULT_BATCH_LIMIT)); #[cfg(feature = "retry")] self.set_retry_const_defaults(); } @@ -70,11 +71,11 @@ impl Config { #[cfg(feature = "affinity")] affinity: self.affinity.into_sync(), #[cfg(feature = "batching")] - batch_size: self.batch_size.into_sync_default(), + batch_limit: self.batch_limit.into_sync_default(), #[cfg(feature = "retry")] - max_retries: self.max_retries.into_sync(), + max_retries: self.max_retries.into_sync_default(), #[cfg(feature = "retry")] - retry_factor: self.retry_factor.into_sync(), + retry_factor: self.retry_factor.into_sync_default(), } } @@ -88,7 +89,7 @@ impl Config { #[cfg(feature = "affinity")] affinity: self.affinity.into_unsync(), #[cfg(feature = "batching")] - batch_size: self.batch_size.into_unsync(), + batch_limit: self.batch_limit.into_unsync(), #[cfg(feature = "retry")] max_retries: self.max_retries.into_unsync(), #[cfg(feature = "retry")] @@ -143,23 +144,6 @@ mod tests { } } -#[cfg(feature = "batching")] -mod batching { - use super::{Config, DEFAULTS}; - - const DEFAULT_BATCH_SIZE: usize = 10; - - pub fn set_batch_size_default(batch_size: usize) { - DEFAULTS.lock().batch_size.set(Some(batch_size)); - } - - impl Config { - pub(super) fn set_batch_const_defaults(&mut self) { - self.batch_size.set(Some(DEFAULT_BATCH_SIZE)); - } - } -} - #[cfg(feature = "retry")] mod retry { use super::{Config, DEFAULTS}; @@ -184,8 +168,21 @@ mod retry { } impl Config { - pub fn set_retry_factor_from(&mut self, duration: Duration) -> Option { - self.retry_factor.set(Some(duration.as_nanos() as u64)) + pub fn get_retry_factor_duration(&self) -> Option { + self.retry_factor.get().map(Duration::from_nanos) + } + + pub fn set_retry_factor_from(&mut self, duration: Duration) -> Option { + self.retry_factor + .set(Some(duration.as_nanos() as u64)) + .map(Duration::from_nanos) + } + + pub fn try_set_retry_factor_from(&self, duration: Duration) -> Option { + self.retry_factor + .try_set(duration.as_nanos() as u64) + .map(Duration::from_nanos) + .ok() } pub(super) fn set_retry_const_defaults(&mut self) { @@ -203,12 +200,6 @@ mod retry { use serial_test::serial; use std::time::Duration; - impl Config { - fn get_retry_factor_duration(&self) -> Option { - self.retry_factor.get().map(Duration::from_nanos) - } - } - #[test] #[serial] fn test_set_max_retries_default() { diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs index 9ce3330..82af070 100644 --- a/src/hive/inner/mod.rs +++ b/src/hive/inner/mod.rs @@ -8,7 +8,7 @@ mod task; pub mod set_config { #[cfg(feature = "batching")] - pub use super::config::set_batch_size_default; + pub use super::config::set_batch_limit_default; pub use super::config::{reset_defaults, set_num_threads_default, set_num_threads_default_all}; #[cfg(feature = "retry")] pub use super::config::{ @@ -17,7 +17,7 @@ pub mod set_config { } pub use 
self::builder::{Builder, BuilderConfig}; -pub use self::queue::{ChannelTaskQueues, TaskQueues}; +pub use self::queue::{ChannelTaskQueues, TaskQueues, WorkerQueues}; use self::counter::DualCounter; use self::gate::{Gate, PhasedGate}; @@ -92,12 +92,11 @@ pub struct Config { thread_name: Any, /// Stack size for each worker thread thread_stack_size: Usize, + /// Maximum number of tasks for a worker thread to take when receiving from the input channel + batch_limit: Usize, /// CPU cores to which worker threads can be pinned #[cfg(feature = "affinity")] affinity: Any, - /// Maximum number of tasks for a worker thread to take when receiving from the input channel - #[cfg(feature = "batching")] - batch_size: Usize, /// Maximum number of retries for a task #[cfg(feature = "retry")] max_retries: U32, diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index fdcff43..aa5012a 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -1,12 +1,13 @@ //! Implementation of `TaskQueues` that uses `crossbeam` channels for the global queue (i.e., for //! sending tasks from the `Hive` to the worker threads) and a default implementation of local //! queues that depends on which combination of the `retry` and `batching` features are enabled. -use super::{PopTaskError, Task, TaskQueues, Token}; +use super::{Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; use crate::atomic::{Atomic, AtomicBool}; -use crate::bee::{Queen, Worker}; -use crate::hive::inner::Shared; +use crate::bee::Worker; use crossbeam_channel::RecvTimeoutError; +use crossbeam_queue::SegQueue; use parking_lot::RwLock; +use std::sync::Arc; use std::time::Duration; // time to wait in between polling the retry queue and then the task receiver @@ -18,134 +19,92 @@ type TaskSender = crossbeam_channel::Sender>; type TaskReceiver = crossbeam_channel::Receiver>; pub struct ChannelTaskQueues { - global_tx: TaskSender, - global_rx: TaskReceiver, - closed: AtomicBool, - /// thread-local queues of tasks used when the `batching` feature is enabled - #[cfg(feature = "batching")] - local_batch_queues: RwLock>>>, - /// thread-local queues used for tasks that are waiting to be retried after a failure - #[cfg(feature = "retry")] - local_retry_queues: RwLock>>>, + global: Arc>, + local: RwLock>>>, } impl TaskQueues for ChannelTaskQueues { + type WorkerQueues = ChannelWorkerQueues; + fn new(_: Token) -> Self { - let (tx, rx) = crossbeam_channel::unbounded(); Self { - global_tx: tx, - global_rx: rx, - closed: AtomicBool::default(), - #[cfg(feature = "batching")] - local_batch_queues: Default::default(), - #[cfg(feature = "retry")] - local_retry_queues: Default::default(), + global: Arc::new(GlobalQueue::new()), + local: Default::default(), } } - fn init_for_threads>( - &self, - start_index: usize, - end_index: usize, - #[allow(unused_variables)] shared: &Shared, - ) { - #[cfg(feature = "batching")] - self.init_batch_queues_for_threads(start_index, end_index, shared); - #[cfg(feature = "retry")] - self.init_retry_queues_for_threads(start_index, end_index); + fn init_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { + let mut local_queues = self.local.write(); + assert_eq!(local_queues.len(), start_index); + (start_index..end_index).for_each(|thread_index| { + local_queues.push(Arc::new(ChannelWorkerQueues::new( + thread_index, + &self.global, + config, + ))) + }); } - fn try_push_global(&self, task: Task) -> Result<(), Task> { - if !self.closed.get() { - self.global_tx - 
.try_send(task) - .map_err(|err| err.into_inner()) - } else { - Err(task) - } + fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { + let local_queues = self.local.write(); + assert!(local_queues.len() > end_index); + local_queues[start_index..end_index] + .iter() + .for_each(|queue| queue.update(config)); } - /// Creates a task from `input` and pushes it to the local queue if there is space, - /// otherwise attempts to add it to the global queue. Returns the task ID if the push - /// succeeds, otherwise returns an error with the input. - fn push_local>( - &self, - task: Task, - #[allow(unused_variables)] thread_index: usize, - shared: &Shared, - ) { - #[cfg(feature = "batching")] - let task = match self.try_push_local(task, thread_index) { - Ok(_) => return, - Err(task) => task, - }; - shared.push_global(task); + fn worker_queues(&self, thread_index: usize) -> Arc { + Arc::clone(&self.local.read()[thread_index]) } - /// Returns the next task from the local queue if there are any, otherwise attempts to - /// fetch at least 1 and up to `batch_size + 1` tasks from the input channel and puts all - /// but the first one into the local queue. - fn try_pop>( - &self, - thread_index: usize, - #[allow(unused_variables)] shared: &Shared, - ) -> Result, PopTaskError> { - // try to get a task from the local queues - #[cfg(feature = "retry")] - if let Some(task) = self.try_pop_retry(thread_index) { - return Ok(task); - } - #[cfg(feature = "batching")] - if let Some(task) = self.try_pop_local_or_refill(thread_index, shared) { - return Ok(task); - } - // fall back to requesting a task from the global queue - self.try_pop_timeout(RECV_TIMEOUT) + fn try_push_global(&self, task: Task) -> Result<(), Task> { + self.global.try_push(task) } - fn drain(&self, _: Token) -> Vec> { - let mut tasks = Vec::from_iter(self.global_rx.try_iter()); - #[cfg(feature = "batching")] - { - self.drain_batch_queues_into(&mut tasks); - } - #[cfg(feature = "retry")] - { - self.drain_retry_queues_into(&mut tasks); + fn close(&self, _: Token) { + self.global.close() + } + + fn drain(self) -> Vec> { + self.close(Token); + let mut tasks = Vec::new(); + let global = crate::hive::unwrap_arc(self.global); + global.drain_into(&mut tasks); + for local in self.local.into_inner().into_iter() { + let local = crate::hive::unwrap_arc(local); + local.drain_into(&mut tasks); } tasks } +} - #[cfg(feature = "batching")] - fn resize_local>( - &self, - start_index: usize, - end_index: usize, - new_size: usize, - shared: &Shared, - ) { - self.resize_batch_queues(start_index, end_index, new_size, shared); - } +pub struct GlobalQueue { + global_tx: TaskSender, + global_rx: TaskReceiver, + closed: AtomicBool, +} - #[cfg(feature = "retry")] - fn retry>( - &self, - task: Task, - thread_index: usize, - shared: &Shared, - ) -> Option { - self.try_push_retry(task, thread_index, shared) +impl GlobalQueue { + fn new() -> Self { + let (tx, rx) = crossbeam_channel::unbounded(); + Self { + global_tx: tx, + global_rx: rx, + closed: Default::default(), + } } - fn close(&self, _: Token) { - self.closed.set(true); + #[inline] + fn try_push(&self, task: Task) -> Result<(), Task> { + if self.closed.get() { + return Err(task); + } + self.global_tx.send(task).map_err(|err| err.into_inner()) } -} -impl ChannelTaskQueues { #[inline] - fn try_pop_timeout(&self, timeout: Duration) -> Result, PopTaskError> { - match self.global_rx.recv_timeout(timeout) { + fn try_pop(&self) -> Result, PopTaskError> { + match 
self.global_rx.recv_timeout(RECV_TIMEOUT) { Ok(task) => Ok(task), Err(RecvTimeoutError::Disconnected) => Err(PopTaskError::Closed), Err(RecvTimeoutError::Timeout) if self.closed.get() && self.global_rx.is_empty() => { @@ -154,215 +113,188 @@ impl<W: Worker> ChannelTaskQueues<W> { Err(RecvTimeoutError::Timeout) => Err(PopTaskError::Empty), } } -} -#[cfg(feature = "retry")] -impl<W: Worker> ChannelTaskQueues<W> { - #[inline] - fn try_pop_retry(&self, thread_index: usize) -> Option<Task<W>> { - self.local_retry_queues - .read() - .get(thread_index) - .and_then(|queue| queue.try_pop()) + fn try_iter(&self) -> impl Iterator<Item = Task<W>> + '_ { + self.global_rx.try_iter() + } + + fn close(&self) { + self.closed.set(true); + } + + fn drain_into(self, tasks: &mut Vec<Task<W>>) { + tasks.extend(self.global_rx.try_iter()); } } -#[cfg(feature = "batching")] -impl<W: Worker> ChannelTaskQueues<W> { - #[inline] - fn try_push_local(&self, task: Task<W>, thread_index: usize) -> Result<(), Task<W>> { - self.local_batch_queues.read()[thread_index].push(task) +pub struct ChannelWorkerQueues<W: Worker> { + _thread_index: usize, + global: Arc<GlobalQueue<W>>, + /// queue of abandoned tasks + local_abandoned: SegQueue<Task<W>>, + /// thread-local queue of tasks used when the `batching` feature is enabled + #[cfg(feature = "batching")] + local_batch: RwLock<crossbeam_queue::ArrayQueue<Task<W>>>, + /// thread-local queues used for tasks that are waiting to be retried after a failure + #[cfg(feature = "retry")] + local_retry: super::delay::DelayQueue<Task<W>>, + #[cfg(feature = "retry")] + retry_factor: crate::atomic::AtomicU64, +} + +impl<W: Worker> ChannelWorkerQueues<W> { + fn new(thread_index: usize, global_queue: &Arc<GlobalQueue<W>>, config: &Config) -> Self { + Self { + _thread_index: thread_index, + global: Arc::clone(global_queue), + local_abandoned: Default::default(), + #[cfg(feature = "batching")] + local_batch: RwLock::new(crossbeam_queue::ArrayQueue::new( + config.batch_limit.get_or_default().max(1), + )), + #[cfg(feature = "retry")] + local_retry: Default::default(), + #[cfg(feature = "retry")] + retry_factor: crate::atomic::AtomicU64::new(config.retry_factor.get_or_default()), + } } - #[inline] - fn try_pop_local_or_refill<Q: Queen<Kind = W>>( - &self, - thread_index: usize, - shared: &Shared<W, Q, Self>, - ) -> Option<Task<W>> { - let local_queue = &self.local_batch_queues.read()[thread_index]; - // pop from the local queue if it has any tasks - if !local_queue.is_empty() { - return local_queue.pop(); + /// Updates the local queues based on the provided `config`: + /// If `batching` is enabled, resizes the batch queue if necessary. + /// If `retry` is enabled, updates the retry factor. + fn update(&self, config: &Config) { + #[cfg(feature = "batching")] + self.update_batch(config); + #[cfg(feature = "retry")] + self.retry_factor.set(config.retry_factor.get_or_default()); + } + + /// Consumes this `ChannelWorkerQueues` and drains the tasks currently in the queues into + /// `tasks`.
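The `retry_factor` stored above feeds the exponential backoff computed in `try_push_retry` below: the delay is `retry_factor * 2^(attempt - 1)` nanoseconds. A standalone restatement of that arithmetic (`retry_delay` is a hypothetical name; this version saturates at `u64::MAX` on overflow, whereas the in-tree version falls back to a zero delay if `2^(attempt - 1)` itself overflows):

```rust
use std::time::Duration;

fn retry_delay(retry_factor_nanos: u64, attempt: u32) -> Duration {
    2u64.checked_pow(attempt - 1)
        .and_then(|multiplier| retry_factor_nanos.checked_mul(multiplier))
        .map(Duration::from_nanos)
        .unwrap_or(Duration::from_nanos(u64::MAX))
}

fn main() {
    // with a 1ms factor the delays double: 1ms, 2ms, 4ms, 8ms, ...
    assert_eq!(retry_delay(1_000_000, 1), Duration::from_millis(1));
    assert_eq!(retry_delay(1_000_000, 4), Duration::from_millis(8));
}
```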
+ fn drain_into(self, tasks: &mut Vec>) { + while let Some(task) = self.local_abandoned.pop() { + tasks.push(task); } - // otherwise pull at least 1 and up to `batch_size + 1` tasks from the input channel - // wait for the next task from the receiver - let first = self.try_pop_timeout(RECV_TIMEOUT).ok(); - // if we fail after trying to get one, don't keep trying to fill the queue - if first.is_some() { - let batch_size = shared.batch_size(); - // batch size 0 means batching is disabled - if batch_size > 0 { - // otherwise try to take up to `batch_size` tasks from the input channel - // and add them to the local queue, but don't block if the input channel - // is empty - for result in self - .global_rx - .try_iter() - .take(batch_size) - .map(|task| local_queue.push(task)) - { - if let Err(task) = result { - // for some reason we can't push the task to the local queue; - // this should never happen, but just in case we turn it into an - // unprocessed outcome and stop iterating - shared.abandon_task(task); - break; - } - } + #[cfg(feature = "batching")] + { + let batch = self.local_batch.into_inner(); + tasks.reserve(batch.len()); + while let Some(task) = batch.pop() { + tasks.push(task); } } - first + #[cfg(feature = "retry")] + self.local_retry.drain_into(tasks); + } +} + +impl WorkerQueues for ChannelWorkerQueues { + fn push(&self, task: Task) { + #[cfg(feature = "batching")] + let task = match self.local_batch.read().push(task) { + Ok(_) => return, + Err(task) => task, + }; + let task = match self.global.try_push(task) { + Ok(_) => return, + Err(task) => task, + }; + self.local_abandoned.push(task); + } + + fn try_pop(&self) -> Result, PopTaskError> { + // first try to get a previously abandoned task + if let Some(task) = self.local_abandoned.pop() { + return Ok(task); + } + // if retry is enabled, try to get a task from the retry queue + #[cfg(feature = "retry")] + if let Some(task) = self.local_retry.try_pop() { + return Ok(task); + } + // if batching is enabled, try to get a task from the batch queue + // and try to refill it from the global queue if it's empty + #[cfg(feature = "batching")] + { + self.try_pop_batch_or_refill().ok_or(PopTaskError::Empty) + } + // fall back to requesting a task from the global queue + #[cfg(not(feature = "batching"))] + self.global.try_pop() + } + + #[cfg(feature = "retry")] + fn try_push_retry(&self, task: Task) -> Result> { + // compute the delay + let delay = 2u64 + .checked_pow(task.attempt - 1) + .and_then(|multiplier| { + self.retry_factor + .get() + .checked_mul(multiplier) + .or(Some(u64::MAX)) + .map(Duration::from_nanos) + }) + .unwrap_or_default(); + self.local_retry.push(task, delay) } } #[cfg(feature = "batching")] mod batching { - use super::{ChannelTaskQueues, Task}; - use crate::bee::{Queen, Worker}; - use crate::hive::inner::Shared; + use super::{ChannelWorkerQueues, Config, Task}; + use crate::bee::Worker; use crossbeam_queue::ArrayQueue; - use std::collections::HashSet; use std::time::Duration; - impl ChannelTaskQueues { - pub(super) fn init_batch_queues_for_threads>( - &self, - start_index: usize, - end_index: usize, - shared: &Shared, - ) { - let mut batch_queues = self.local_batch_queues.write(); - assert_eq!(batch_queues.len(), start_index); - let queue_size = shared.batch_size().max(1); - (start_index..end_index).for_each(|_| batch_queues.push(ArrayQueue::new(queue_size))); - } - - pub(super) fn resize_batch_queues>( - &self, - start_index: usize, - end_index: usize, - batch_size: usize, - shared: &Shared, - ) { - // keep track 
of which queues need to be resized - // TODO: this method could cause a hang if one of the worker threads is stuck - we - // might want to keep track of each queue's size and if we don't see it shrink - // within a certain amount of time, we give up on that thread and leave it with a - // wrong-sized queue (which should never cause a panic) - let mut to_resize: HashSet = (start_index..end_index).collect(); - // iterate until we've resized them all - loop { - // scope the mutable access to local_queues + impl ChannelWorkerQueues { + pub fn update_batch(&self, config: &Config) { + let batch_limit = config.batch_limit.get_or_default().max(1); + let mut queue = self.local_batch.write(); + // block until the current queue is small enough that it can fit into the new queue + while queue.len() > batch_limit { + std::thread::sleep(Duration::from_millis(10)); + } + let new_queue = ArrayQueue::new(batch_limit); + while let Some(task) = queue.pop() { + if let Err(task) = new_queue + .push(task) + .or_else(|task| self.global.try_push(task)) { - let mut batch_queues = self.local_batch_queues.write(); - to_resize.retain(|thread_index| { - let queue = if let Some(queue) = batch_queues.get_mut(*thread_index) { - queue - } else { - return false; - }; - if queue.len() > batch_size { - return true; - } - let new_queue = ArrayQueue::new(batch_size); - while let Some(task) = queue.pop() { - if let Err(task) = new_queue.push(task) { - // for some reason we can't push the task to the new queue - // this should never happen, but just in case we turn it into - // an unprocessed outcome - shared.abandon_task(task); - } - } - // this is safe because the worker threads can't get readable access to the - // queue while this thread holds the lock - let old_queue = std::mem::replace(queue, new_queue); - assert!(old_queue.is_empty()); - false - }); - } - if !to_resize.is_empty() { - // short sleep to give worker threads the chance to pull from their queues - std::thread::sleep(Duration::from_millis(10)); + self.local_abandoned.push(task); + break; } } + assert!(queue.is_empty()); + *queue = new_queue; } - pub(super) fn drain_batch_queues_into(&self, tasks: &mut Vec>) { - let _ = self - .local_batch_queues - .write() - .iter_mut() - .fold(tasks, |tasks, queue| { - tasks.reserve(queue.len()); - while let Some(task) = queue.pop() { - tasks.push(task); + pub(super) fn try_pop_batch_or_refill(&self) -> Option> { + // pop from the local queue if it has any tasks + let local_queue = self.local_batch.read(); + if !local_queue.is_empty() { + return local_queue.pop(); + } + // otherwise pull at least 1 and up to `batch_limit + 1` tasks from the input channel + // wait for the next task from the receiver + let first = self.global.try_pop().ok(); + // if we fail after trying to get one, don't keep trying to fill the queue + if first.is_some() { + let batch_limit = local_queue.capacity(); + // batch size 0 means batching is disabled + if batch_limit > 0 { + // otherwise try to take up to `batch_limit` tasks from the input channel + // and add them to the local queue, but don't block if the input channel + // is empty + for task in self.global.try_iter().take(batch_limit) { + if let Err(task) = local_queue.push(task) { + self.local_abandoned.push(task); + break; + } } - tasks - }); - } - } -} - -#[cfg(feature = "retry")] -mod retry { - use super::{ChannelTaskQueues, Task}; - use crate::bee::{Queen, Worker}; - use crate::hive::inner::queue::delay::DelayQueue; - use crate::hive::inner::Shared; - use std::time::{Duration, Instant}; - - 
impl ChannelTaskQueues { - /// Initializes the retry queues worker threads in the specified range. - pub(super) fn init_retry_queues_for_threads(&self, start_index: usize, end_index: usize) { - let mut retry_queues = self.local_retry_queues.write(); - assert_eq!(retry_queues.len(), start_index); - (start_index..end_index).for_each(|_| retry_queues.push(DelayQueue::default())) - } - - /// Adds a task to the retry queue with a delay based on `attempt`. - pub(super) fn try_push_retry>( - &self, - task: Task, - thread_index: usize, - shared: &Shared, - ) -> Option { - // compute the delay - let delay = shared - .config - .retry_factor - .get() - .map(|retry_factor| { - 2u64.checked_pow(task.attempt - 1) - .and_then(|multiplier| { - retry_factor - .checked_mul(multiplier) - .or(Some(u64::MAX)) - .map(Duration::from_nanos) - }) - .unwrap() - }) - .unwrap_or_default(); - if let Some(queue) = self.local_retry_queues.read().get(thread_index) { - queue.push(task, delay) - } else { - Err(task) + } } - // if unable to queue the task, abandon it - .map_err(|task| shared.abandon_task(task)) - .ok() - } - - pub(super) fn drain_retry_queues_into(&self, tasks: &mut Vec>) { - let _ = self - .local_retry_queues - .write() - .iter_mut() - .fold(tasks, |tasks, queue| { - tasks.reserve(queue.len()); - tasks.extend(queue.drain()); - tasks - }); + first } } } diff --git a/src/hive/inner/queue/delay.rs b/src/hive/inner/queue/delay.rs index 7faa6c4..1073808 100644 --- a/src/hive/inner/queue/delay.rs +++ b/src/hive/inner/queue/delay.rs @@ -7,20 +7,21 @@ use std::time::{Duration, Instant}; /// /// This is implemented internally as a `UnsafeCell`. /// -/// SAFETY: This data structure is designed to enable the queue to be modified by a *single thread* -/// using interior mutability. `UnsafeCell` is used for performance - this is safe so long as the -/// queue is only accessed from a single thread at a time. This data structure is *not* thread-safe. +/// SAFETY: This data structure is designed to enable the queue to be modified (using `push` and +/// `try_pop`) by a *single thread* using interior mutability. The `drain` method is called by a +/// different thread, but it first takes ownership of the queue and so will never be called +/// concurrently with `push/pop`. +/// +/// `UnsafeCell` is used for performance - this is safe so long as the queue is only accessed from +/// a single thread at a time. This data structure is *not* thread-safe. #[derive(Debug)] pub struct DelayQueue(UnsafeCell>>); impl DelayQueue { - /// Returns the number of items currently in the queue. - pub fn len(&self) -> usize { - unsafe { self.0.get().as_ref().unwrap().len() } - } - /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available, /// or an error with `item` if there was an error pushing the item. + /// + /// SAFETY: this method is only ever called within a single thread. pub fn push(&self, item: T, delay: Duration) -> Result { unsafe { match self.0.get().as_mut() { @@ -35,28 +36,20 @@ impl DelayQueue { } } - /// Returns the `Instant` at which the next item will be available. Returns `None` if the queue - /// is empty. - pub fn next_available(&self) -> Option { - unsafe { - self.0 - .get() - .as_ref() - .and_then(|queue| queue.peek().map(|head| head.until)) - } - } - /// Returns the item at the head of the queue, if one exists and is available (i.e., its delay /// has been exceeded), and removes it. + /// + /// SAFETY: this method is only ever called within a single thread. 
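Setting the `UnsafeCell` aside, the delay queue's behavior documented above is just a min-heap keyed by an "available at" `Instant`. A safe, single-owner (`&mut self`) restatement of the same ordering logic, with the names `SimpleDelayQueue` and `Delayed` invented here for illustration:

```rust
use std::collections::BinaryHeap;
use std::time::{Duration, Instant};

struct Delayed<T> {
    until: Instant,
    value: T,
}

// compare by `until` only, reversed so the earliest deadline sits at the
// head of the (max-)heap
impl<T> PartialEq for Delayed<T> {
    fn eq(&self, other: &Self) -> bool {
        self.until == other.until
    }
}
impl<T> Eq for Delayed<T> {}
impl<T> PartialOrd for Delayed<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl<T> Ord for Delayed<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other.until.cmp(&self.until)
    }
}

struct SimpleDelayQueue<T>(BinaryHeap<Delayed<T>>);

impl<T> SimpleDelayQueue<T> {
    /// Pushes an item; returns the instant at which it becomes available.
    fn push(&mut self, value: T, delay: Duration) -> Instant {
        let until = Instant::now() + delay;
        self.0.push(Delayed { until, value });
        until
    }

    /// Pops the head item only if its delay has elapsed.
    fn try_pop(&mut self) -> Option<T> {
        if self.0.peek().is_some_and(|d| d.until <= Instant::now()) {
            self.0.pop().map(|d| d.value)
        } else {
            None
        }
    }
}
```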
pub fn try_pop(&self) -> Option { unsafe { - if self - .next_available() - .map(|until| until <= Instant::now()) + let queue_ptr = self.0.get(); + if queue_ptr + .as_ref() + .and_then(|queue| queue.peek()) + .map(|head| head.until <= Instant::now()) .unwrap_or(false) { - self.0 - .get() + queue_ptr .as_mut() .and_then(|queue| queue.pop()) .map(|delayed| delayed.value) @@ -66,9 +59,11 @@ impl DelayQueue { } } - /// Drains all items from the queue and returns them as an iterator. - pub fn drain(&mut self) -> impl Iterator + '_ { - self.0.get_mut().drain().map(|delayed| delayed.value) + /// Consumes this `DelayQueue` and drains all items from the queue into `sink`. + pub fn drain_into(self, sink: &mut Vec) { + let mut queue = self.0.into_inner(); + sink.reserve(queue.len()); + sink.extend(queue.drain().map(|delayed| delayed.value)) } } @@ -125,6 +120,12 @@ mod tests { use super::DelayQueue; use std::{thread, time::Duration}; + impl DelayQueue { + fn len(&self) -> usize { + unsafe { self.0.get().as_ref().unwrap().len() } + } + } + #[test] fn test_works() { let queue = DelayQueue::default(); @@ -153,11 +154,12 @@ mod tests { #[test] fn test_into_vec() { - let mut queue = DelayQueue::default(); + let queue = DelayQueue::default(); queue.push(1, Duration::from_secs(1)).unwrap(); queue.push(2, Duration::from_secs(2)).unwrap(); queue.push(3, Duration::from_secs(3)).unwrap(); - let mut v: Vec<_> = queue.drain().collect(); + let mut v = Vec::new(); + queue.drain_into(&mut v); v.sort(); assert_eq!(v, vec![1, 2, 3]); } diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs index b0a4f59..337ebaa 100644 --- a/src/hive/inner/queue/mod.rs +++ b/src/hive/inner/queue/mod.rs @@ -5,8 +5,9 @@ mod delay; pub use self::channel::ChannelTaskQueues; -use super::{Shared, Task, Token}; -use crate::bee::{Queen, Worker}; +use super::{Config, Task, Token}; +use crate::bee::Worker; +use std::sync::Arc; /// Errors that may occur when trying to pop tasks from the global queue. #[derive(thiserror::Error, Debug)] @@ -22,43 +23,46 @@ pub enum PopTaskError { /// /// This trait is sealed - it cannot be implemented outside of this crate. pub trait TaskQueues: Sized + Send + Sync + 'static { + type WorkerQueues: WorkerQueues; + /// Returns a new instance. + /// + /// The private `Token` is used to prevent this method from being called externally. fn new(token: Token) -> Self; /// Initializes the local queues for the given range of worker thread indices. - fn init_for_threads>( - &self, - start_index: usize, - end_index: usize, - shared: &Shared, - ); + fn init_for_threads(&self, start_index: usize, end_index: usize, config: &Config); + + /// Updates the queue settings from `config` for the given range of worker threads. + fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config); - /// Changes the size of the local queues to `new_size`. - #[cfg(feature = "batching")] - fn resize_local>( - &self, - start_index: usize, - end_index: usize, - new_size: usize, - shared: &Shared, - ); + /// Returns a new `WorkerQueues` instance for a thread. + fn worker_queues(&self, thread_index: usize) -> Arc; /// Tries to add a task to the global queue. /// /// Returns an error with the task if the queue is disconnected. fn try_push_global(&self, task: Task) -> Result<(), Task>; - /// Attempts to add a task to the local queue if space is available, otherwise adds it to the - /// global queue. + /// Closes this `GlobalQueue` so no more tasks may be pushed. 
+ /// + /// The private `Token` is used to prevent this method from being called externally. + fn close(&self, token: Token); + + /// Drains all tasks from all global and local queues and returns them as a `Vec`. /// - /// If adding to the global queue fails, the task is abandoned (converted to an - /// `Outcome::Unprocessed` and sent to the outcome channel or stored in the hive). - fn push_local<Q: Queen<Kind = W>>( - &self, - task: Task<W>, - thread_index: usize, - shared: &Shared<W, Q, Self>, - ); + /// This is a destructive operation - if `close` has not been called, it will be called before + /// draining the queues. + fn drain(self) -> Vec<Task<W>>; +} + +/// Trait that provides each worker thread with access to the task queues. Implementations of this +/// trait can hold thread-local types that are not Send/Sync. +pub trait WorkerQueues<W: Worker> { + /// Attempts to add a task to the local queue if space is available, otherwise adds it to the + /// global queue. If adding to the global queue fails, the task is added to a local "abandoned" + /// queue from which it may be popped or will otherwise be converted to an `Outcome::Unprocessed`. + fn push(&self, task: Task<W>); /// Attempts to remove a task from the local queue for the given worker thread index. If there /// are no local queues, or if the local queues are empty, falls back to taking a task from the @@ -68,29 +72,12 @@ pub trait TaskQueues<W: Worker>: Sized + Send + Sync + 'static { /// definition of "available". /// /// Also returns an error if the queue is empty or disconnected. - fn try_pop<Q: Queen<Kind = W>>( - &self, - thread_index: usize, - shared: &Shared<W, Q, Self>, - ) -> Result<Task<W>, PopTaskError>; - - /// Drains all tasks from all global and local queues and returns them as a `Vec`. - fn drain(&self, token: Token) -> Vec<Task<W>>; + fn try_pop(&self) -> Result<Task<W>, PopTaskError>; /// Attempts to add `task` to the local retry queue. /// /// Returns the earliest `Instant` at which it might be retried. If the task could not be added - /// to the retry queue (e.g., if the queue is full), the task is abandoned (converted to - /// `Outcome::Unprocessed` and sent to the outcome channel or stored in the hive) and this - /// method returns `None`. + /// to the retry queue (e.g., if the queue is full), the task is returned as an error. #[cfg(feature = "retry")] - fn retry<Q: Queen<Kind = W>>( - &self, - task: Task<W>, - thread_index: usize, - shared: &Shared<W, Q, Self>, - ) -> Option<Instant>; - - /// Closes this `GlobalQueue` so no more tasks may be pushed.
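The `push` contract documented above is a three-stage cascade: bounded local batch queue, then the global queue, then an unbounded "abandoned" queue that is drained later. Sketched in isolation (the `global_push` closure is a stand-in for the real global-queue send, which fails once the queue is closed):

```rust
use crossbeam_queue::{ArrayQueue, SegQueue};

fn cascade_push<T>(
    local: &ArrayQueue<T>,
    global_push: impl FnOnce(T) -> Result<(), T>,
    abandoned: &SegQueue<T>,
    task: T,
) {
    // 1) bounded local queue; push() hands the task back if the queue is full
    let task = match local.push(task) {
        Ok(()) => return,
        Err(task) => task,
    };
    // 2) global queue; fails if the queue has been closed
    let task = match global_push(task) {
        Ok(()) => return,
        Err(task) => task,
    };
    // 3) last resort: never drop the task, park it for later draining
    abandoned.push(task);
}
```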
- fn close(&self, token: Token); + fn try_push_retry(&self, task: Task) -> Result>; } diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 345db3e..a754478 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -1,36 +1,135 @@ -use super::{GlobalTaskQueue, LocalTaskQueues, Task}; -use crate::bee::{Queen, Worker}; -use crate::hive::Shared; -use crossbeam_deque::{Injector, Stealer, Worker as LocalQueue}; -use std::marker::PhantomData; - -struct GlobalQueue { - queue: Injector>, - _worker: PhantomData, +use super::{Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; +use crate::atomic::{Atomic, AtomicBool, AtomicUsize}; +use crate::bee::Worker; +use crossbeam_deque::{Injector as GlobalQueue, Stealer, Worker as LocalQueue}; +use parking_lot::{Mutex, RwLock}; + +pub struct WorkstealingWorkerQueues { + local_batch: LocalQueue, + batch_limit: AtomicUsize, + /// thread-local queues used for tasks that are waiting to be retried after a failure + #[cfg(feature = "retry")] + local_retry: super::delay::DelayQueue>, + #[cfg(feature = "retry")] + retry_factor: crate::atomic::AtomicU64, } -impl GlobalTaskQueue for GlobalQueue { - fn try_push(&self, task: Task) -> Result<(), Task> { - self.queue.push(task); +impl WorkerQueues for WorkstealingWorkerQueues { + fn push(&self, task: Task) { + todo!() } - fn try_pop(&self) -> Result, super::PopTaskError> { + fn try_pop(&self) -> Result, PopTaskError> { todo!() } - fn try_iter(&self) -> impl Iterator> + '_ { + fn try_push_retry(&self, task: Task) -> Result> { todo!() } +} - fn drain(&self) -> Vec> { - todo!() +struct WorkstealingQueues { + global_queue: GlobalQueue>, + local_queues: RwLock>>, + local_stealers: RwLock>>>, + closed: AtomicBool, +} + +impl TaskQueues for WorkstealingQueues { + type WorkerQueues = WorkstealingWorkerQueues; + + fn new(_: Token) -> Self { + Self { + global_queue: Default::default(), + local_worker_queues: Default::default(), + local_stealers: Default::default(), + batch_limit: AtomicUsize::new(config.batch_limit.get_or_default()), + closed: Default::default(), + #[cfg(feature = "retry")] + local_retry_queues: Default::default(), + #[cfg(feature = "retry")] + retry_factor: crate::atomic::AtomicU64::new(config.retry_factor.get_or_default()), + } + } + + fn init_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { + let mut local_queues = self.local_worker_queues.write(); + assert_eq!(local_queues.len(), start_index); + let mut stealers = self.local_stealers.write(); + (start_index..end_index).for_each(|_| { + let local_queue = LocalQueue::new_fifo(); + let stealer = local_queue.stealer(); + local_queues.push(Mutex::new(local_queue)); + stealers.push(stealer); + }); + //#[cfg(feature = "retry")] + //self.init_retry_queues_for_threads(start_index, end_index); } - fn close(&self) { + fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { todo!() } -} -struct WorkerQueue {} + fn try_push_global(&self, task: Task) -> Result<(), Task> { + if !self.closed.get() { + self.global_queue.push(task); + Ok(()) + } else { + Err(task) + } + } + + fn try_push_local(&self, task: Task, thread_index: usize) -> Result<(), Task> { + self.local_worker_queues.read()[thread_index] + .lock() + .push(task); + Ok(()) + } + + fn try_pop(&self, thread_index: usize) -> Result, PopTaskError> { + // first try popping from the local queue + { + let worker_queue_mutex = &self.local_worker_queues.read()[thread_index]; + let 
worker_queue = worker_queue_mutex.lock(); + worker_queue.pop().or_else(|| { + self.global_queue + .steal_batch_with_limit_and_pop(&worker_queue, self.batch_limit.get()) + .success() + }) + } + .or_else(|| { + // TODO: randomize the order + self.local_stealers + .read() + .iter() + .filter_map(|stealer| stealer.steal().success()) + .next() + }) + .ok_or(PopTaskError::Empty) + } + + fn drain(&self) -> Vec<Task<W>> { + let mut tasks = Vec::new(); + while let Some(task) = self.global_queue.steal().success() { + tasks.push(task); + } + let local_queues = self.local_worker_queues.read(); + local_queues + .iter() + .fold(tasks, |mut tasks, local_queue_mutex| { + let local_queue = local_queue_mutex.lock(); + while let Some(task) = local_queue.pop() { + tasks.push(task); + } + tasks + }) + } + + fn retry(&self, task: Task<W>, thread_index: usize) -> Result<Instant, Task<W>> { + todo!() + } -impl LocalTaskQueues> for WorkerQueue {} + fn close(&self, _: Token) { + self.closed.set(true); + } +} diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index f421cbf..1f5b5ed 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -1,4 +1,4 @@ -use super::{Config, PopTaskError, Shared, Task, TaskQueues, Token}; +use super::{Config, PopTaskError, Shared, Task, TaskQueues, Token, WorkerQueues}; use crate::atomic::{Atomic, AtomicInt, AtomicUsize}; use crate::bee::{Queen, TaskId, Worker}; use crate::channel::SenderExt; @@ -6,6 +6,7 @@ use crate::hive::{Husk, Outcome, OutcomeSender, SpawnError}; use parking_lot::MutexGuard; use std::collections::HashMap; use std::ops::DerefMut; +use std::sync::Arc; use std::thread::{Builder, JoinHandle}; use std::{fmt, iter}; @@ -13,10 +14,11 @@ impl<W: Worker, Q: Queen<Kind = W>, T: TaskQueues<W>> Shared<W, Q, T> { /// Creates a new `Shared` instance with the given configuration, queen, and task receiver, /// and all other fields set to their default values. pub fn new(config: Config, queen: Q) -> Self { + let task_queues = T::new(Token); Shared { config, queen, - task_queues: T::new(Token), + task_queues, spawn_results: Default::default(), num_tasks: Default::default(), next_task_id: Default::default(), @@ -88,7 +90,7 @@ impl<W: Worker, Q: Queen<Kind = W>, T: TaskQueues<W>> Shared<W, Q, T> { let end_index = start_index + num_threads; // if worker threads need a local queue, initialize them before spawning self.task_queues - .init_for_threads(start_index, end_index, self); + .init_for_threads(start_index, end_index, &self.config); // spawn the worker threads and return the results let results: Vec<_> = (start_index..end_index).map(f).collect(); spawn_results.reserve(num_threads); @@ -140,10 +142,16 @@ impl<W: Worker, Q: Queen<Kind = W>, T: TaskQueues<W>> Shared<W, Q, T> { .count() } + /// Returns the mutex guard for the results of spawning worker threads. pub fn spawn_results(&self) -> MutexGuard<Vec<Result<JoinHandle<()>, SpawnError>>> { self.spawn_results.lock() } + /// Returns the `WorkerQueues` instance for the worker thread with the specified index. + pub fn worker_queues(&self, thread_index: usize) -> Arc<T::WorkerQueues> { + self.task_queues.worker_queues(thread_index) + } + /// Returns a new `Worker` from the queen, or an error if a `Worker` could not be created. pub fn create_worker(&self) -> Q::Kind { self.queen.create() @@ -151,7 +159,7 @@ impl<W: Worker, Q: Queen<Kind = W>, T: TaskQueues<W>> Shared<W, Q, T> { /// Increments the number of queued tasks. Returns a new `Task` with the provided input and /// `outcome_tx` and the next ID.
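The workstealing `try_pop` above follows the standard crossbeam-deque strategy: pop the thread's own queue first, then steal a batch from the global `Injector` (popping one task), then steal a single task from a sibling. For reference, this is essentially the canonical loop from crossbeam-deque's documentation:

```rust
use crossbeam_deque::{Injector, Stealer, Worker};
use std::iter;

fn find_task<T>(
    local: &Worker<T>,
    global: &Injector<T>,
    stealers: &[Stealer<T>],
) -> Option<T> {
    // fast path: the thread's own FIFO queue
    local.pop().or_else(|| {
        // keep trying the global injector and then the sibling stealers
        // until we get a task or every source reports Empty
        iter::repeat_with(|| {
            global
                .steal_batch_and_pop(local)
                .or_else(|| stealers.iter().map(|s| s.steal()).collect())
        })
        .find(|s| !s.is_retry())
        .and_then(|s| s.success())
    })
}
```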
- fn prepare_task(&self, input: W::Input, outcome_tx: Option<&OutcomeSender>) -> Task { + pub fn prepare_task(&self, input: W::Input, outcome_tx: Option<&OutcomeSender>) -> Task { self.num_tasks .increment_left(1) .expect("overflowed queued task counter"); @@ -190,21 +198,6 @@ impl, T: TaskQueues> Shared { task_id } - /// Creates a new `Task` for the given input and outcome channel, and attempts to add it to - /// the local queue for the specified `thread_index`. Falls back to adding it to the global - /// queue. - pub fn send_one_local( - &self, - input: W::Input, - outcome_tx: Option<&OutcomeSender>, - thread_index: usize, - ) -> TaskId { - let task = self.prepare_task(input, outcome_tx); - let task_id = task.id(); - self.task_queues.push_local(task, thread_index, self); - task_id - } - /// Creates a new `Task` for each input in the given batch and sends them to the global queue. pub fn send_batch_global( &self, @@ -270,7 +263,7 @@ impl, T: TaskQueues> Shared { /// /// Returns an error if the hive is poisoned or if the local queues are empty, and the global /// queue is disconnected. - pub fn get_next_task(&self, thread_index: usize) -> Option> { + pub fn get_next_task(&self, worker_queues: &T::WorkerQueues) -> Option> { loop { // block while the hive is suspended self.wait_on_resume(); @@ -279,7 +272,7 @@ impl, T: TaskQueues> Shared { return None; } // get the next task from the queue - break if its closed - match self.task_queues.try_pop(thread_index, self) { + match worker_queues.try_pop() { Ok(task) => break Some(task), Err(PopTaskError::Closed) => break None, Err(PopTaskError::Empty) => continue, @@ -367,6 +360,39 @@ impl, T: TaskQueues> Shared { self.no_work_notify_all(); } + /// Returns the local queue batch size. + pub fn worker_batch_limit(&self) -> usize { + self.config.batch_limit.get().unwrap_or_default() + } + + /// Changes the local queue batch size. This requires allocating a new queue for each + /// worker thread. + /// + /// Note: this method will block the current thread waiting for all local queues to become + /// writable; if `batch_limit` is less than the current batch size, this method will also + /// block while any thread's queue length is > `batch_limit` before moving the elements. + /// + /// TODO: this needs to be moved to an extension that is specific to channel hive + pub fn set_worker_batch_limit(&self, batch_limit: usize) -> usize { + // update the batch size first so any new threads spawned won't need to have their + // queues resized + let prev_batch_limit = self + .config + .batch_limit + .try_set(batch_limit) + .unwrap_or_default(); + if prev_batch_limit == batch_limit { + return prev_batch_limit; + } + let num_threads = self.num_threads(); + if num_threads == 0 { + return prev_batch_limit; + } + self.task_queues + .update_for_threads(0, num_threads, &self.config); + prev_batch_limit + } + /// Returns a reference to the `Queen`. /// /// Note that, if the queen is a `QueenMut`, the returned value will be a `QueenCell`, and it @@ -375,11 +401,6 @@ impl, T: TaskQueues> Shared { &self.queen } - /// Returns a reference to the `Config`. - pub fn config(&self) -> &Config { - &self.config - } - /// Returns a tuple with the number of (queued, active) tasks. #[inline] pub fn num_tasks(&self) -> (u64, u64) { @@ -428,12 +449,13 @@ impl, T: TaskQueues> Shared { self.num_referrers.sub(1) } - /// Sets the `poisoned` flag to `true`. Converts all queued tasks to `Outcome::Unprocessed` - /// and stores them in `outcomes`. 
Also automatically resumes the hive if it is suspendend, - /// which enables blocked worker threads to terminate. + /// Performs the following actions: + /// 1. Sets the `poisoned` flag to `true` + /// 2. Closes all task queues so no more tasks may be pushed + /// 3. Resumes the hive if it is suspended, which enables blocked worker threads to terminate. pub fn poison(&self) { self.poisoned.set(true); - self.drain_tasks_into_unprocessed(); + self.close_task_queues(); self.set_suspended(false); } @@ -498,29 +520,47 @@ impl<W: Worker, Q: Queen<Kind = W>, T: TaskQueues<W>> Shared<W, Q, T> { .collect() } - /// Drains all queued tasks, converts them into `Outcome::Unprocessed` outcomes, and tries - /// to send them or (if the task does not have a sender, or if the send fails) stores them - /// in the `outcomes` map. - fn drain_tasks_into_unprocessed(&self) { - self.abandon_batch(self.task_queues.drain(Token).into_iter()); - } - /// Close the tasks queues so no more tasks can be added. - pub fn close(&self) { + pub fn close_task_queues(&self) { self.task_queues.close(Token); } + fn flush( + task_queues: T, + mut outcomes: HashMap<TaskId, Outcome<W>>, + ) -> HashMap<TaskId, Outcome<W>> { + task_queues.close(Token); + for task in task_queues.drain().into_iter() { + let task_id = task.id(); + let (outcome, outcome_tx) = task.into_unprocessed(); + if let Some(outcome) = if let Some(tx) = outcome_tx { + tx.try_send_msg(outcome) + } else { + Some(outcome) + } { + outcomes.insert(task_id, outcome); + } + } + outcomes + } + + /// Consumes this `Shared`, closes and drains task queues, converts any queued tasks into + /// `Outcome::Unprocessed` outcomes, and tries to send them or (if the task does not have a + /// sender, or if the send fails) stores them in the `outcomes` map. Returns the outcome map. + pub fn into_outcomes(self) -> HashMap<TaskId, Outcome<W>> { + Self::flush(self.task_queues, self.outcomes.into_inner()) + } + /// Consumes this `Shared` and returns a `Husk` containing the `Queen`, panic count, stored /// outcomes, and all configuration information necessary to create a new `Hive`. Any queued /// tasks are converted into `Outcome::Unprocessed` outcomes and either sent to the task's /// sender or (if there is no sender, or the send fails) stored in the `outcomes` map. pub fn into_husk(self) -> Husk<W, Q> { - self.drain_tasks_into_unprocessed(); Husk::new( self.config.into_unsync(), self.queen, self.num_panics.into_inner(), - self.outcomes.into_inner(), + Self::flush(self.task_queues, self.outcomes.into_inner()), ) } } @@ -572,10 +612,12 @@ mod affinity { } } -#[cfg(feature = "batching")] -mod batching { - use super::{Shared, TaskQueues}; - use crate::bee::{Queen, Worker}; +#[cfg(feature = "retry")] +mod retry { + use crate::bee::{Queen, TaskId, Worker}; + use crate::hive::inner::{Shared, Task, TaskQueues}; + use crate::hive::{OutcomeSender, WorkerQueues}; + use std::time::{Duration, Instant}; impl Shared where @@ -583,52 +625,55 @@ mod retry { Q: Queen, T: TaskQueues, { - /// Returns the local queue batch size. - pub fn batch_size(&self) -> usize { - self.config.batch_size.get().unwrap_or_default() + /// Returns the current worker retry limit. + pub fn worker_retry_limit(&self) -> u32 { + self.config.max_retries.get().unwrap_or_default() } - /// Changes the local queue batch size. This requires allocating a new queue for each - /// worker thread.
- /// - /// Note: this method will block the current thread waiting for all local queues to become - /// writable; if `batch_size` is less than the current batch size, this method will also - /// block while any thread's queue length is > `batch_size` before moving the elements. - pub fn set_batch_size(&self, batch_size: usize) -> usize { - // update the batch size first so any new threads spawned won't need to have their - // queues resized - let prev_batch_size = self + /// Sets the worker retry limit and returns the previous value. + pub fn set_worker_retry_limit(&self, max_retries: u32) -> u32 { + let prev_retry_limit = self .config - .batch_size - .try_set(batch_size) + .max_retries + .try_set(max_retries) .unwrap_or_default(); - if prev_batch_size == batch_size { - return prev_batch_size; + if prev_retry_limit == max_retries { + return prev_retry_limit; } let num_threads = self.num_threads(); if num_threads == 0 { - return prev_batch_size; + return prev_retry_limit; } self.task_queues - .resize_local(0, num_threads, batch_size, self); - prev_batch_size + .update_for_threads(0, num_threads, &self.config); + prev_retry_limit } - } -} -#[cfg(feature = "retry")] -mod retry { - use crate::bee::{Queen, TaskId, Worker}; - use crate::hive::inner::{Shared, Task, TaskQueues}; - use crate::hive::OutcomeSender; - use std::time::Instant; + /// Returns the current worker retry factor. + pub fn worker_retry_factor(&self) -> Duration { + Duration::from_millis(self.config.retry_factor.get().unwrap_or_default()) + } + + /// Sets the worker retry factor and returns the previous value. + pub fn set_worker_retry_factor(&self, duration: Duration) -> Duration { + let prev_retry_factor = Duration::from_nanos( + self.config + .retry_factor + .try_set(duration.as_nanos() as u64) + .unwrap_or_default(), + ); + if prev_retry_factor == duration { + return prev_retry_factor; + } + let num_threads = self.num_threads(); + if num_threads == 0 { + return prev_retry_factor; + } + self.task_queues + .update_for_threads(0, num_threads, &self.config); + prev_retry_factor + } - impl Shared - where - W: Worker, - Q: Queen, - T: TaskQueues, - { /// Returns `true` if the hive is configured to retry tasks and the `attempt` field of the /// given `ctx` is less than the maximum number of retries. pub fn can_retry(&self, attempt: u32) -> bool { @@ -641,19 +686,19 @@ mod retry { /// Adds a task with the given `task_id`, `input`, and `outcome_tx` to the local retry /// queue for the specified `thread_index`. - pub fn send_retry( + pub fn try_send_retry( &self, task_id: TaskId, input: W::Input, - outcome_tx: Option>, + outcome_tx: Option<&OutcomeSender>, attempt: u32, - thread_index: usize, - ) -> Option { + worker_queues: &T::WorkerQueues, + ) -> Result> { self.num_tasks .increment_left(1) .expect("overflowed queued task counter"); - let task = Task::with_attempt(task_id, input, outcome_tx, attempt); - self.task_queues.retry(task, thread_index, self) + let task = Task::with_attempt(task_id, input, outcome_tx.cloned(), attempt); + worker_queues.try_push_retry(task) } } } diff --git a/src/hive/mod.rs b/src/hive/mod.rs index fbd65a4..6125e5c 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -137,8 +137,8 @@ //! queue. This behavior is activated by enabling the `batching` feature. //! //! With the `batching` feature enabled, `Builder` gains the -//! [`batch_size`](crate::hive::Builder::batch_size) method for configuring size of worker threads' -//! 
local queues, and `Hive` gains the [`set_worker_batch_size`](crate::hive::Hive::set_batch_size) +//! [`batch_limit`](crate::hive::Builder::batch_limit) method for configuring the size of worker threads' +//! local queues, and `Hive` gains the [`set_worker_batch_limit`](crate::hive::Hive::set_worker_batch_limit) //! method for changing the batch size of an existing `Hive`. //! //! ## Global defaults @@ -152,7 +152,7 @@ //! * `num_threads` //! * [`set_num_threads_default`]: sets the default to a specific value //! * [`set_num_threads_default_all`]: sets the default to all available CPU cores -//! * [`batch_size`](crate::hive::set_batch_size_default) (requires `feature = "batching"`) +//! * [`batch_limit`](crate::hive::set_batch_limit_default) (requires `feature = "batching"`) //! * [`max_retries`](crate::hive::set_max_retries_default] (requires `feature = "retry"`) //! * [`retry_factor`](crate::hive::set_retry_factor_default] (requires `feature = "retry"`) //! @@ -370,7 +370,7 @@ pub use self::husk::Husk; pub use self::inner::{set_config::*, Builder, ChannelTaskQueues}; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; -use self::inner::{Config, Shared, Task, TaskQueues}; +use self::inner::{Config, Shared, Task, TaskQueues, WorkerQueues}; use self::outcome::{DerefOutcomes, OutcomeQueue, OwnedOutcomes}; use crate::bee::Worker; use crate::channel::{channel, Receiver, Sender}; @@ -395,6 +395,28 @@ pub mod prelude { }; } +fn unwrap_arc<T>(mut arc: std::sync::Arc<T>) -> T { + // wait for worker threads to drop, then take ownership of the shared data and convert it + // into a Husk + let mut backoff = None::<crossbeam_utils::Backoff>; + loop { + // TODO: may want to have some timeout or other kind of limit to prevent this from + // looping forever if a worker thread somehow gets stuck, or if the `num_referrers` + // counter is corrupted + arc = match std::sync::Arc::try_unwrap(arc) { + Ok(inner) => { + return inner; + } + Err(arc) => { + backoff + .get_or_insert_with(crossbeam_utils::Backoff::new) + .spin(); + arc + } + }; + } +} + #[cfg(test)] mod tests { use super::{ @@ -1768,12 +1790,12 @@ mod batching_tests { fn run_test( hive: &Hive>, ChannelTaskQueues>>, num_threads: usize, - batch_size: usize, + batch_limit: usize, ) { - let tasks_per_thread = batch_size + 2; + let tasks_per_thread = batch_limit + 2; let (tx, rx) = crate::hive::outcome_channel(); - // each worker should take `batch_size` tasks for its queue + 1 to work on immediately, - // meaning there should be `batch_size + 1` tasks associated with each thread ID + // each worker should take `batch_limit` tasks for its queue + 1 to work on immediately, - // meaning there should be `batch_limit + 1` tasks associated with each thread ID let barrier = IndexedBarrier::new(num_threads); let task_ids = launch_tasks(hive, num_threads, tasks_per_thread, &barrier, &tx); // start the first tasks @@ -1790,45 +1812,45 @@ #[test] fn test_batching() { const NUM_THREADS: usize = 4; - const BATCH_SIZE: usize = 24; + const BATCH_LIMIT: usize = 24; let hive = ChannelBuilder::empty() .with_worker_default() .num_threads(NUM_THREADS) - .batch_size(BATCH_SIZE) + .batch_limit(BATCH_LIMIT) .build(); - run_test(&hive, NUM_THREADS, BATCH_SIZE); + run_test(&hive, NUM_THREADS, BATCH_LIMIT); } #[test] - fn test_set_batch_size() { + fn test_set_batch_limit() { const NUM_THREADS: usize = 4; - const BATCH_SIZE_0: usize = 10; - const BATCH_SIZE_1: usize = 20; - const BATCH_SIZE_2: usize = 50; + const BATCH_LIMIT_0: usize = 10; + const BATCH_LIMIT_1: usize =
20; + const BATCH_LIMIT_2: usize = 50; let hive = ChannelBuilder::empty() .with_worker_default() .num_threads(NUM_THREADS) - .batch_size(BATCH_SIZE_0) + .batch_limit(BATCH_LIMIT_0) .build(); - run_test(&hive, NUM_THREADS, BATCH_SIZE_0); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_0); // increase batch size - hive.set_worker_batch_size(BATCH_SIZE_2); - run_test(&hive, NUM_THREADS, BATCH_SIZE_2); + hive.set_worker_batch_limit(BATCH_LIMIT_2); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_2); // decrease batch size - hive.set_worker_batch_size(BATCH_SIZE_1); - run_test(&hive, NUM_THREADS, BATCH_SIZE_1); + hive.set_worker_batch_limit(BATCH_LIMIT_1); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_1); } #[test] - fn test_shrink_batch_size() { + fn test_shrink_batch_limit() { const NUM_THREADS: usize = 4; const NUM_TASKS_PER_THREAD: usize = 125; - const BATCH_SIZE_0: usize = 100; - const BATCH_SIZE_1: usize = 10; + const BATCH_LIMIT_0: usize = 100; + const BATCH_LIMIT_1: usize = 10; let hive = ChannelBuilder::empty() .with_worker_default() .num_threads(NUM_THREADS) - .batch_size(BATCH_SIZE_0) + .batch_limit(BATCH_LIMIT_0) .build(); let (tx, rx) = crate::hive::outcome_channel(); let barrier = IndexedBarrier::new(NUM_THREADS); @@ -1836,13 +1858,13 @@ mod batching_tests { let total_tasks = NUM_THREADS * NUM_TASKS_PER_THREAD; assert_eq!(task_ids.len(), total_tasks); barrier.wait(); - hive.set_worker_batch_size(BATCH_SIZE_1); + hive.set_worker_batch_limit(BATCH_LIMIT_1); // The number of tasks completed by each thread could be variable, so we want to ensure - // that a) each processed at least `BATCH_SIZE_0` tasks, and b) there are a total of + // that a) each processed at least `BATCH_LIMIT_0` tasks, and b) there are a total of // `NUM_TASKS` outputs with no errors hive.join(); let thread_counts = count_thread_ids(rx, task_ids); - assert!(thread_counts.values().all(|count| *count > BATCH_SIZE_0)); + assert!(thread_counts.values().all(|count| *count > BATCH_LIMIT_0)); assert_eq!(thread_counts.values().sum::(), total_tasks); } } diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs index bec9154..6d3c83b 100644 --- a/src/hive/outcome/queue.rs +++ b/src/hive/outcome/queue.rs @@ -1,9 +1,9 @@ -use super::Outcome; +use super::{DerefOutcomes, Outcome}; use crate::bee::{TaskId, Worker}; use crossbeam_queue::SegQueue; use parking_lot::Mutex; use std::collections::HashMap; -use std::ops::DerefMut; +use std::ops::{Deref, DerefMut}; pub struct OutcomeQueue { queue: SegQueue>, @@ -55,3 +55,13 @@ impl Default for OutcomeQueue { } } } + +impl DerefOutcomes for OutcomeQueue { + fn outcomes_deref(&self) -> impl Deref>> { + self.outcomes.lock() + } + + fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { + self.outcomes.lock() + } +} From d715e0aa8a9044dbea8bc5053f1e51ed4e85fe2f Mon Sep 17 00:00:00 2001 From: jdidion Date: Wed, 19 Feb 2025 10:09:20 -0800 Subject: [PATCH 12/67] refactor; finish workstealing queue impl --- CHANGELOG.md | 2 +- Cargo.toml | 3 + src/hive/inner/config.rs | 3 +- src/hive/inner/mod.rs | 5 +- src/hive/inner/queue/channel.rs | 24 +- src/hive/inner/queue/mod.rs | 9 +- src/hive/inner/queue/{delay.rs => retry.rs} | 105 +++++-- src/hive/inner/queue/workstealing.rs | 321 ++++++++++++++------ src/hive/inner/shared.rs | 3 +- src/hive/inner/task.rs | 31 ++ 10 files changed, 350 insertions(+), 156 deletions(-) rename src/hive/inner/queue/{delay.rs => retry.rs} (57%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b7ce28..ae21dc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The 
general theme of this release is performance improvement by eliminating thre * Added the `TaskQueues` trait, which enables `Hive` to be specialized for different implementations of global (i.e., sending tasks from the `Hive` to worker threads) and local (i.e., worker thread-specific) queues. * `ChannelTaskQueues` implements the existing behavior, using a channel for sending tasks. * `WorkstealingTaskQueues` has been added to implement the workstealing pattern, based on `crossbeam::dequeue`. - * Added the `batching` feature, which enables `ChannelTaskQueues` to use worker-thread local queues to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. + * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks. * Added the `Context::submit` method, which enables tasks to submit new tasks to the `Hive`. * Other * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread-contention. diff --git a/Cargo.toml b/Cargo.toml index 95c2478..061ea0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,9 @@ num = "0.4.3" num_cpus = "1.16.0" parking_lot = "0.12.3" paste = "1.0.15" +rand = { version = "0.9.0", default-features = false, features = [ + "thread_rng", +] } thiserror = "1.0.63" # required with the `affinity` feature core_affinity = { version = "0.8.1", optional = true } diff --git a/src/hive/inner/config.rs b/src/hive/inner/config.rs index b1d2dd4..e41fcbc 100644 --- a/src/hive/inner/config.rs +++ b/src/hive/inner/config.rs @@ -44,9 +44,10 @@ impl Config { num_threads: Default::default(), thread_name: Default::default(), thread_stack_size: Default::default(), - batch_limit: Default::default(), #[cfg(feature = "affinity")] affinity: Default::default(), + #[cfg(feature = "batching")] + batch_limit: Default::default(), #[cfg(feature = "retry")] max_retries: Default::default(), #[cfg(feature = "retry")] diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs index 82af070..13f9844 100644 --- a/src/hive/inner/mod.rs +++ b/src/hive/inner/mod.rs @@ -92,11 +92,12 @@ pub struct Config { thread_name: Any, /// Stack size for each worker thread thread_stack_size: Usize, - /// Maximum number of tasks for a worker thread to take when receiving from the input channel - batch_limit: Usize, /// CPU cores to which worker threads can be pinned #[cfg(feature = "affinity")] affinity: Any, + /// Maximum number of tasks for a worker thread to take when receiving from the input channel + #[cfg(feature = "batching")] + batch_limit: Usize, /// Maximum number of retries for a task #[cfg(feature = "retry")] max_retries: U32, diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index aa5012a..af934b6 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -25,6 +25,7 @@ pub struct ChannelTaskQueues { impl TaskQueues for ChannelTaskQueues { type WorkerQueues = ChannelWorkerQueues; + type WorkerQueuesTarget = Arc; fn new(_: Token) -> Self { Self { @@ -137,9 +138,7 @@ pub struct ChannelWorkerQueues { local_batch: RwLock>>, /// thread-local queues used for tasks that are waiting to be retried after a failure #[cfg(feature = "retry")] - local_retry: super::delay::DelayQueue>, - #[cfg(feature = "retry")] - retry_factor: crate::atomic::AtomicU64, + local_retry: super::retry::RetryQueue, 
} impl ChannelWorkerQueues { @@ -153,9 +152,7 @@ impl ChannelWorkerQueues { config.batch_limit.get_or_default().max(1), )), #[cfg(feature = "retry")] - local_retry: Default::default(), - #[cfg(feature = "retry")] - retry_factor: crate::atomic::AtomicU64::new(config.retry_factor.get_or_default()), + local_retry: super::retry::RetryQueue::new(config.retry_factor.get_or_default()), } } @@ -166,7 +163,7 @@ impl ChannelWorkerQueues { #[cfg(feature = "batching")] self.update_batch(config); #[cfg(feature = "retry")] - self.retry_factor.set(config.retry_factor.get_or_default()); + self.local_retry.set_delay_factor(config.retry_factor.get_or_default()); } /// Consumes this `ChannelWorkerQueues` and drains the tasks currently in the queues into @@ -225,18 +222,7 @@ impl WorkerQueues for ChannelWorkerQueues { #[cfg(feature = "retry")] fn try_push_retry(&self, task: Task) -> Result> { - // compute the delay - let delay = 2u64 - .checked_pow(task.attempt - 1) - .and_then(|multiplier| { - self.retry_factor - .get() - .checked_mul(multiplier) - .or(Some(u64::MAX)) - .map(Duration::from_nanos) - }) - .unwrap_or_default(); - self.local_retry.push(task, delay) + self.local_retry.try_push(task) } } diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs index 337ebaa..0823a58 100644 --- a/src/hive/inner/queue/mod.rs +++ b/src/hive/inner/queue/mod.rs @@ -1,13 +1,13 @@ mod channel; #[cfg(feature = "retry")] -mod delay; -//mod workstealing; +mod retry; +mod workstealing; pub use self::channel::ChannelTaskQueues; use super::{Config, Task, Token}; use crate::bee::Worker; -use std::sync::Arc; +use std::ops::Deref; /// Errors that may occur when trying to pop tasks from the global queue. #[derive(thiserror::Error, Debug)] @@ -24,6 +24,7 @@ pub enum PopTaskError { /// This trait is sealed - it cannot be implemented outside of this crate. pub trait TaskQueues: Sized + Send + Sync + 'static { type WorkerQueues: WorkerQueues; + type WorkerQueuesTarget: Deref; /// Returns a new instance. /// @@ -37,7 +38,7 @@ pub trait TaskQueues: Sized + Send + Sync + 'static { fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config); /// Returns a new `WorkerQueues` instance for a thread. - fn worker_queues(&self, thread_index: usize) -> Arc; + fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueuesTarget; /// Tries to add a task to the global queue. /// diff --git a/src/hive/inner/queue/delay.rs b/src/hive/inner/queue/retry.rs similarity index 57% rename from src/hive/inner/queue/delay.rs rename to src/hive/inner/queue/retry.rs index 1073808..347f8a7 100644 --- a/src/hive/inner/queue/delay.rs +++ b/src/hive/inner/queue/retry.rs @@ -1,3 +1,6 @@ +use crate::atomic::{Atomic, AtomicU64}; +use crate::bee::Worker; +use crate::hive::Task; use std::cell::UnsafeCell; use std::cmp::Ordering; use std::collections::BinaryHeap; @@ -15,23 +18,48 @@ use std::time::{Duration, Instant}; /// `UnsafeCell` is used for performance - this is safe so long as the queue is only accessed from /// a single thread at a time. This data structure is *not* thread-safe. 
 #[derive(Debug)]
-pub struct DelayQueue<T>(UnsafeCell<BinaryHeap<Delayed<T>>>);
+pub struct RetryQueue<W: Worker> {
+    inner: UnsafeCell<BinaryHeap<DelayedTask<W>>>,
+    delay_factor: AtomicU64,
+}
+
+impl<W: Worker> RetryQueue<W> {
+    pub fn new(delay_factor: u64) -> Self {
+        Self {
+            inner: UnsafeCell::new(BinaryHeap::new()),
+            delay_factor: AtomicU64::new(delay_factor),
+        }
+    }
+
+    pub fn set_delay_factor(&self, delay_factor: u64) {
+        self.delay_factor.set(delay_factor);
+    }
 
-impl<T> DelayQueue<T> {
-    /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available,
-    /// or an error with `item` if there was an error pushing the item.
+    /// Pushes a task onto the queue with a delay computed from its `attempt` field. Returns the
+    /// `Instant` at which the task will be available, or an error with the task if it could not
+    /// be pushed.
     ///
     /// SAFETY: this method is only ever called within a single thread.
-    pub fn push(&self, item: T, delay: Duration) -> Result<Instant, T> {
+    pub fn try_push(&self, task: Task<W>) -> Result<Instant, Task<W>> {
+        // compute the delay
+        let delay = 2u64
+            .checked_pow(task.attempt - 1)
+            .and_then(|multiplier| {
+                self.delay_factor
+                    .get()
+                    .checked_mul(multiplier)
+                    .or(Some(u64::MAX))
+                    .map(Duration::from_nanos)
+            })
+            .unwrap_or_default();
         unsafe {
-            match self.0.get().as_mut() {
+            match self.inner.get().as_mut() {
                 Some(queue) => {
-                    let delayed = Delayed::new(item, delay);
+                    let delayed = Delayed::new(task, delay);
                     let until = delayed.until;
                     queue.push(delayed);
                     Ok(until)
                 }
-                None => Err(item),
+                None => Err(task),
             }
         }
     }
@@ -40,9 +68,9 @@
     /// has been exceeded), and removes it.
     ///
     /// SAFETY: this method is only ever called within a single thread.
-    pub fn try_pop(&self) -> Option<T> {
+    pub fn try_pop(&self) -> Option<Task<W>> {
         unsafe {
-            let queue_ptr = self.0.get();
+            let queue_ptr = self.inner.get();
             if queue_ptr
                 .as_ref()
                 .and_then(|queue| queue.peek())
@@ -59,21 +87,17 @@
         }
     }
 
-    /// Consumes this `DelayQueue` and drains all items from the queue into `sink`.
-    pub fn drain_into(self, sink: &mut Vec<T>) {
-        let mut queue = self.0.into_inner();
+    /// Consumes this `RetryQueue` and drains all items from the queue into `sink`.
+    pub fn drain_into(self, sink: &mut Vec<Task<W>>) {
+        let mut queue = self.inner.into_inner();
         sink.reserve(queue.len());
         sink.extend(queue.drain().map(|delayed| delayed.value))
     }
 }
 
-unsafe impl<T> Sync for DelayQueue<T> {}
+unsafe impl<W: Worker> Sync for RetryQueue<W> {}
 
-impl<T> Default for DelayQueue<T> {
-    fn default() -> Self {
-        DelayQueue(UnsafeCell::new(BinaryHeap::new()))
-    }
-}
+type DelayedTask<W> = Delayed<Task<W>>;
 
 #[derive(Debug)]
 struct Delayed<T> {
@@ -91,7 +115,7 @@
 }
 
 /// Implements ordering for `Delayed`, so it can be used to correctly order elements in the
-/// `BinaryHeap` of the `DelayQueue`.
+/// `BinaryHeap` of the `RetryQueue`.
 ///
 /// Earlier entries have higher priority (should be popped first), so they are greater than later
 /// entries.
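As a worked example of the backoff computed in `try_push` above: with a retry factor of 1 ms, attempt 1 is delayed by 1 ms, attempt 2 by 2 ms, and attempt 3 by 4 ms; the multiplication saturates at `u64::MAX` nanoseconds rather than overflowing. The following standalone sketch mirrors the same arithmetic (illustrative only; `backoff` is not a function in this crate):

use std::time::Duration;

// delay = factor * 2^(attempt - 1), saturating the multiply on overflow;
// an overflowing 2^(attempt - 1) falls through to Duration::default(),
// matching the `unwrap_or_default()` in `try_push`
fn backoff(factor_nanos: u64, attempt: u32) -> Duration {
    2u64.checked_pow(attempt - 1)
        .and_then(|multiplier| {
            factor_nanos
                .checked_mul(multiplier)
                .or(Some(u64::MAX))
                .map(Duration::from_nanos)
        })
        .unwrap_or_default()
}

fn main() {
    let factor = Duration::from_millis(1).as_nanos() as u64;
    assert_eq!(backoff(factor, 1), Duration::from_millis(1));
    assert_eq!(backoff(factor, 2), Duration::from_millis(2));
    assert_eq!(backoff(factor, 3), Duration::from_millis(4));
}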
@@ -117,36 +141,44 @@ impl Eq for Delayed {} #[cfg(test)] mod tests { - use super::DelayQueue; + use super::{RetryQueue, Task, Worker}; + use crate::bee::stock::EchoWorker; use std::{thread, time::Duration}; - impl DelayQueue { + type TestWorker = EchoWorker; + const DELAY: u64 = Duration::from_secs(1).as_nanos() as u64; + + impl RetryQueue { fn len(&self) -> usize { - unsafe { self.0.get().as_ref().unwrap().len() } + unsafe { self.inner.get().as_ref().unwrap().len() } } } #[test] fn test_works() { - let queue = DelayQueue::default(); + let queue = RetryQueue::::new(DELAY); - queue.push(1, Duration::from_secs(1)).unwrap(); - queue.push(2, Duration::from_secs(2)).unwrap(); - queue.push(3, Duration::from_secs(3)).unwrap(); + let task1 = Task::with_attempt(1, 1, None, 1); + let task2 = Task::with_attempt(2, 2, None, 2); + let task3 = Task::with_attempt(3, 3, None, 3); + + queue.try_push(task1.clone()).unwrap(); + queue.try_push(task2.clone()).unwrap(); + queue.try_push(task3.clone()).unwrap(); assert_eq!(queue.len(), 3); assert_eq!(queue.try_pop(), None); thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(1)); + assert_eq!(queue.try_pop(), Some(task1)); assert_eq!(queue.len(), 2); thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(2)); + assert_eq!(queue.try_pop(), Some(task2)); assert_eq!(queue.len(), 1); thread::sleep(Duration::from_secs(1)); - assert_eq!(queue.try_pop(), Some(3)); + assert_eq!(queue.try_pop(), Some(task3)); assert_eq!(queue.len(), 0); assert_eq!(queue.try_pop(), None); @@ -154,13 +186,20 @@ mod tests { #[test] fn test_into_vec() { - let queue = DelayQueue::default(); - queue.push(1, Duration::from_secs(1)).unwrap(); - queue.push(2, Duration::from_secs(2)).unwrap(); - queue.push(3, Duration::from_secs(3)).unwrap(); + let queue = RetryQueue::::new(DELAY); + + let task1 = Task::with_attempt(1, 1, None, 1); + let task2 = Task::with_attempt(2, 2, None, 2); + let task3 = Task::with_attempt(3, 3, None, 3); + + queue.try_push(task1.clone()).unwrap(); + queue.try_push(task2.clone()).unwrap(); + queue.try_push(task3.clone()).unwrap(); + let mut v = Vec::new(); queue.drain_into(&mut v); v.sort(); - assert_eq!(v, vec![1, 2, 3]); + + assert_eq!(v, vec![task1, task2, task3]); } } diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index a754478..e8dba0f 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -1,135 +1,268 @@ use super::{Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; use crate::atomic::{Atomic, AtomicBool, AtomicUsize}; use crate::bee::Worker; -use crossbeam_deque::{Injector as GlobalQueue, Stealer, Worker as LocalQueue}; -use parking_lot::{Mutex, RwLock}; +use crossbeam_deque::{Injector, Stealer}; +use crossbeam_queue::SegQueue; +use parking_lot::RwLock; +use rand::prelude::*; +use std::ops::Deref; +use std::sync::Arc; -pub struct WorkstealingWorkerQueues { - local_batch: LocalQueue, - batch_limit: AtomicUsize, - /// thread-local queues used for tasks that are waiting to be retried after a failure - #[cfg(feature = "retry")] - local_retry: super::delay::DelayQueue>, - #[cfg(feature = "retry")] - retry_factor: crate::atomic::AtomicU64, +struct WorkstealingTaskQueues { + global: Arc>, + local: RwLock>>>, } -impl WorkerQueues for WorkstealingWorkerQueues { - fn push(&self, task: Task) { - todo!() +impl TaskQueues for WorkstealingTaskQueues { + type WorkerQueues = WorkstealingWorkerQueues; + type WorkerQueuesTarget = Self::WorkerQueues; + 
+ fn new(_: Token) -> Self { + Self { + global: Arc::new(GlobalQueue::new()), + local: Default::default(), + } } - fn try_pop(&self) -> Result, PopTaskError> { - todo!() + fn init_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { + let mut local_queues = self.local.write(); + assert_eq!(local_queues.len(), start_index); + (start_index..end_index).for_each(|thread_index| { + local_queues.push(Arc::new(LocalQueueShared::new( + thread_index, + &self.global, + config, + ))); + }); } - fn try_push_retry(&self, task: Task) -> Result> { - todo!() + fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { + let local_queues = self.local.read(); + assert!(local_queues.len() > end_index); + (start_index..end_index).for_each(|thread_index| local_queues[thread_index].update(config)); + } + + fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueuesTarget { + let local_queue = crossbeam_deque::Worker::new_fifo(); + self.global.add_stealer(local_queue.stealer()); + let shared = &self.local.read()[thread_index]; + WorkstealingWorkerQueues::new(local_queue, Arc::clone(&shared)) + } + + fn try_push_global(&self, task: Task) -> Result<(), Task> { + self.global.try_push(task) + } + + fn close(&self, _: Token) { + self.global.close(); + } + + fn drain(self) -> Vec> { + self.close(Token); + let mut tasks = Vec::new(); + let global = crate::hive::unwrap_arc(self.global); + global.drain_into(&mut tasks); + for local in self.local.into_inner().into_iter() { + let local = crate::hive::unwrap_arc(local); + local.drain_into(&mut tasks); + } + tasks } } -struct WorkstealingQueues { - global_queue: GlobalQueue>, - local_queues: RwLock>>, - local_stealers: RwLock>>>, +pub struct GlobalQueue { + queue: Injector>, + stealers: RwLock>>>, closed: AtomicBool, } -impl TaskQueues for WorkstealingQueues { - type WorkerQueues = WorkstealingWorkerQueues; - - fn new(_: Token) -> Self { +impl GlobalQueue { + fn new() -> Self { Self { - global_queue: Default::default(), - local_worker_queues: Default::default(), - local_stealers: Default::default(), - batch_limit: AtomicUsize::new(config.batch_limit.get_or_default()), - closed: Default::default(), - #[cfg(feature = "retry")] - local_retry_queues: Default::default(), - #[cfg(feature = "retry")] - retry_factor: crate::atomic::AtomicU64::new(config.retry_factor.get_or_default()), + queue: Injector::new(), + stealers: Default::default(), + closed: AtomicBool::default(), } } - fn init_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { - let mut local_queues = self.local_worker_queues.write(); - assert_eq!(local_queues.len(), start_index); - let mut stealers = self.local_stealers.write(); - (start_index..end_index).for_each(|_| { - let local_queue = LocalQueue::new_fifo(); - let stealer = local_queue.stealer(); - local_queues.push(Mutex::new(local_queue)); - stealers.push(stealer); - }); - //#[cfg(feature = "retry")] - //self.init_retry_queues_for_threads(start_index, end_index); + fn add_stealer(&self, stealer: Stealer>) { + self.stealers.write().push(stealer); } - fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { - todo!() + fn try_push(&self, task: Task) -> Result<(), Task> { + if self.closed.get() { + return Err(task); + } + self.queue.push(task); + Ok(()) } - fn try_push_global(&self, task: Task) -> Result<(), Task> { - if !self.closed.get() { - self.global_queue.push(task); - Ok(()) + fn try_steal(&self) -> Option> { + let stealers = self.stealers.read(); + 
let n = stealers.len(); + // randomize the stealing order, to prevent always stealing from the same thread + // TODO: put this into a shared global to prevent creating a new instance on every call + std::iter::from_fn(|| Some(rand::rng().random_range(0..n))) + .take(n) + .filter_map(|i| stealers[i].steal().success()) + .next() + } + + /// Tries to steal a task from the global queue, otherwise tries to steal a task from another + /// worker thread. + #[cfg(not(feature = "batching"))] + fn try_pop(&self) -> Result, PopTaskError> { + if let Some(task) = self.queue.steal().success() { + Ok(task) } else { - Err(task) + self.try_steal().ok_or(PopTaskError::Empty) } } - fn try_push_local(&self, task: Task, thread_index: usize) -> Result<(), Task> { - self.local_worker_queues.read()[thread_index] - .lock() - .push(task); - Ok(()) + /// Tries to steal up to `limit` tasks from the global queue. If at least one task was stolen, + /// it is popped and returned. Otherwise tries to steal a task from another worker thread. + #[cfg(feature = "batching")] + fn try_refill_and_pop( + &self, + local_batch: &crossbeam_deque::Worker>, + limit: usize, + ) -> Result, PopTaskError> { + if let Some(task) = self + .queue + .steal_batch_with_limit_and_pop(local_batch, limit) + .success() + { + return Ok(task); + } + self.try_steal().ok_or(PopTaskError::Empty) } - fn try_pop(&self, thread_index: usize) -> Result, PopTaskError> { - // first try popping from the local queue - { - let worker_queue_mutex = &self.local_worker_queues.read()[thread_index]; - let worker_queue = worker_queue_mutex.lock(); - worker_queue.pop().or_else(|| { - self.global_queue - .steal_batch_with_limit_and_pop(&worker_queue, self.batch_limit.get()) - .success() - }) + fn close(&self) { + self.closed.set(true); + } + + fn drain_into(self, tasks: &mut Vec>) { + while let Some(task) = self.queue.steal().success() { + tasks.push(task); } - .or_else(|| { - // TODO: randomize the order - self.local_stealers - .read() - .iter() - .filter_map(|stealer| stealer.steal().success()) - .next() + self.stealers.into_inner().into_iter().for_each(|stealer| { + while let Some(task) = stealer.steal().success() { + tasks.push(task); + } }) - .ok_or(PopTaskError::Empty) } +} - fn drain(&self) -> Vec> { - let mut tasks = Vec::new(); - while let Some(task) = self.global_queue.steal().success() { - tasks.push(task); +pub struct WorkstealingWorkerQueues { + queue: crossbeam_deque::Worker>, + shared: Arc>, +} + +impl WorkstealingWorkerQueues { + fn new(queue: crossbeam_deque::Worker>, shared: Arc>) -> Self { + Self { queue, shared } + } +} + +impl WorkerQueues for WorkstealingWorkerQueues { + fn push(&self, task: Task) { + self.queue.push(task); + } + + fn try_pop(&self) -> Result, PopTaskError> { + self.shared.try_pop(&self.queue) + } + + #[cfg(feature = "retry")] + fn try_push_retry(&self, task: Task) -> Result> { + self.shared.try_push_retry(task) + } +} + +impl Deref for WorkstealingWorkerQueues { + type Target = Self; + + fn deref(&self) -> &Self::Target { + self + } +} + +struct LocalQueueShared { + _thread_index: usize, + global: Arc>, + /// queue of abandon tasks + local_abandoned: SegQueue>, + #[cfg(feature = "batching")] + batch_limit: AtomicUsize, + /// thread-local queues used for tasks that are waiting to be retried after a failure + #[cfg(feature = "retry")] + local_retry: super::retry::RetryQueue, +} + +impl LocalQueueShared { + fn new(thread_index: usize, global: &Arc>, config: &Config) -> Self { + Self { + _thread_index: thread_index, + global: 
Arc::clone(global), + local_abandoned: Default::default(), + #[cfg(feature = "batching")] + batch_limit: AtomicUsize::new(config.batch_limit.get_or_default()), + #[cfg(feature = "retry")] + local_retry: super::retry::RetryQueue::new(config.retry_factor.get_or_default()), } - let local_queues = self.local_worker_queues.read(); - local_queues - .iter() - .fold(tasks, |mut tasks, local_queue_mutex| { - let local_queue = local_queue_mutex.lock(); - while let Some(task) = local_queue.pop() { - tasks.push(task); - } - tasks - }) } - fn retry(&self, task: Task, thread_index: usize) -> Result> { - todo!() + fn update(&self, config: &Config) { + #[cfg(feature = "batching")] + self.batch_limit.set(config.batch_limit.get_or_default()); + #[cfg(feature = "retry")] + self.local_retry + .set_delay_factor(config.retry_factor.get_or_default()); } - fn close(&self, _: Token) { - self.closed.set(true); + fn try_pop( + &self, + local_batch: &crossbeam_deque::Worker>, + ) -> Result, PopTaskError> { + // first try to get a previously abandoned task + if let Some(task) = self.local_abandoned.pop() { + return Ok(task); + } + // if retry is enabled, try to get a task from the retry queue + #[cfg(feature = "retry")] + if let Some(task) = self.local_retry.try_pop() { + return Ok(task); + } + // next try the local queue + if let Some(task) = local_batch.pop() { + return Ok(task); + } + // fall back to requesting a task from the global queue - this will also refill the local + // batch queue if the batching feature is enabled + if let Some(task) = local_batch.pop() { + return Ok(task); + } + #[cfg(feature = "batching")] + { + self.global + .try_refill_and_pop(local_batch, self.batch_limit.get()) + } + #[cfg(not(feature = "batching"))] + { + self.global.try_pop() + } + } + + fn drain_into(self, tasks: &mut Vec>) { + while let Some(task) = self.local_abandoned.pop() { + tasks.push(task); + } + #[cfg(feature = "retry")] + self.local_retry.drain_into(tasks); + } + + #[cfg(feature = "retry")] + fn try_push_retry(&self, task: Task) -> Result> { + self.local_retry.try_push(task) } } diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index 1f5b5ed..c881c83 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -6,7 +6,6 @@ use crate::hive::{Husk, Outcome, OutcomeSender, SpawnError}; use parking_lot::MutexGuard; use std::collections::HashMap; use std::ops::DerefMut; -use std::sync::Arc; use std::thread::{Builder, JoinHandle}; use std::{fmt, iter}; @@ -148,7 +147,7 @@ impl, T: TaskQueues> Shared { } /// Returns the `WorkerQueues` instance for the worker thread with the specified index. 
- pub fn worker_queues(&self, thread_index: usize) -> Arc { + pub fn worker_queues(&self, thread_index: usize) -> T::WorkerQueuesTarget { self.task_queues.worker_queues(thread_index) } diff --git a/src/hive/inner/task.rs b/src/hive/inner/task.rs index 5a2f634..7123acb 100644 --- a/src/hive/inner/task.rs +++ b/src/hive/inner/task.rs @@ -66,3 +66,34 @@ impl Task { (self.id, self.input, self.attempt, self.outcome_tx) } } + +impl> Clone for Task { + fn clone(&self) -> Self { + Self { + id: self.id.clone(), + input: self.input.clone(), + outcome_tx: self.outcome_tx.clone(), + attempt: self.attempt.clone(), + } + } +} + +impl PartialEq for Task { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl Eq for Task {} + +impl PartialOrd for Task { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Task { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.id.cmp(&other.id) + } +} From 5edcab00cc0c86498ad6bca1b020fbde0da41ef6 Mon Sep 17 00:00:00 2001 From: jdidion Date: Wed, 19 Feb 2025 11:43:19 -0800 Subject: [PATCH 13/67] fix lints --- src/atomic.rs | 2 +- src/hive/builder/bee.rs | 15 ++-- src/hive/builder/full.rs | 2 +- src/hive/builder/mod.rs | 6 +- src/hive/builder/open.rs | 6 +- src/hive/hive.rs | 113 +++++++++++++++------------ src/hive/inner/builder.rs | 52 ++++++------ src/hive/inner/config.rs | 27 +++++-- src/hive/inner/mod.rs | 2 +- src/hive/inner/queue/channel.rs | 14 ++-- src/hive/inner/queue/mod.rs | 1 + src/hive/inner/queue/workstealing.rs | 20 ++--- src/hive/inner/shared.rs | 80 +++++++++++-------- src/hive/inner/task.rs | 104 ++++++++++++++---------- src/hive/mod.rs | 2 +- 15 files changed, 261 insertions(+), 185 deletions(-) diff --git a/src/atomic.rs b/src/atomic.rs index 2f08c18..20b0e97 100644 --- a/src/atomic.rs +++ b/src/atomic.rs @@ -388,7 +388,7 @@ mod affinity { } } -#[cfg(feature = "batching")] +#[cfg(any(feature = "batching", feature = "retry"))] mod batching { use super::{Atomic, AtomicOption, MutError}; use std::fmt::Debug; diff --git a/src/hive/builder/bee.rs b/src/hive/builder/bee.rs index 466acf4..2cb2083 100644 --- a/src/hive/builder/bee.rs +++ b/src/hive/builder/bee.rs @@ -1,6 +1,6 @@ use super::{BuilderConfig, FullBuilder, Token}; use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; -use crate::hive::{ChannelTaskQueues, Config, TaskQueues}; +use crate::hive::{ChannelTaskQueues, Config, TaskQueues, WorkstealingTaskQueues}; /// A Builder for creating `Hive` instances for specific [`Worker`] and [`TaskQueues`] types. #[derive(Clone, Default)] @@ -14,7 +14,7 @@ impl BeeBuilder { pub fn empty>(queen: Q) -> Self { Self { config: Config::empty(), - queen: queen.into(), + queen, } } @@ -23,16 +23,13 @@ impl BeeBuilder { pub fn preset>(queen: Q) -> Self { Self { config: Config::default(), - queen: queen.into(), + queen, } } /// Creates a new `BeeBuilder` from an existing `config` and a `queen`. 
pub(super) fn from(config: Config, queen: Q) -> Self { - Self { - config, - queen: queen.into(), - } + Self { config, queen } } /// Creates a new `FullBuilder` with the current configuration and queen and specified @@ -46,6 +43,10 @@ impl BeeBuilder { pub fn with_channel_queues(self) -> FullBuilder> { FullBuilder::from(self.config, self.queen) } + + pub fn with_workstealing_queues(self) -> FullBuilder> { + FullBuilder::from(self.config, self.queen) + } } impl BeeBuilder { diff --git a/src/hive/builder/full.rs b/src/hive/builder/full.rs index e2a88e7..902e878 100644 --- a/src/hive/builder/full.rs +++ b/src/hive/builder/full.rs @@ -15,7 +15,7 @@ impl> FullBuilder { pub fn empty>(queen: Q) -> Self { Self { config: Config::empty(), - queen: queen.into(), + queen, _queues: PhantomData, } } diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs index 7c809cd..3f6401e 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -2,11 +2,11 @@ //! which provides methods to set configuration parameters. //! //! * Open: has no type parameters; can only set config parameters. Has methods to create -//! typed builders. +//! typed builders. //! * Bee-typed: has type parameters for the `Worker` and `Queen` types. //! * Queue-typed: builder instances that are specific to the `TaskQueues` type. //! * Fully-typed: builder that has type parameters for the `Worker`, `Queen`, and `TaskQueues` -//! types. This is the only builder with a `build` method to create a `Hive`. +//! types. This is the only builder with a `build` method to create a `Hive`. //! //! Generic - Queue //! | / @@ -17,11 +17,13 @@ mod bee; mod channel; mod full; mod open; +mod workstealing; pub use bee::BeeBuilder; pub use channel::ChannelBuilder; pub use full::FullBuilder; pub use open::OpenBuilder; +pub use workstealing::WorkstealingBuilder; use crate::hive::inner::{BuilderConfig, Token}; diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs index 89d9ab8..b9791b5 100644 --- a/src/hive/builder/open.rs +++ b/src/hive/builder/open.rs @@ -1,4 +1,4 @@ -use super::{BeeBuilder, BuilderConfig, ChannelBuilder, Token}; +use super::{BeeBuilder, BuilderConfig, ChannelBuilder, Token, WorkstealingBuilder}; use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; use crate::hive::Config; @@ -262,6 +262,10 @@ impl OpenBuilder { pub fn with_channel_queues(self) -> ChannelBuilder { ChannelBuilder::from(self.0) } + + pub fn with_workstealing_queues(self) -> WorkstealingBuilder { + WorkstealingBuilder::from(self.0) + } } impl BuilderConfig for OpenBuilder { diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 19e86c8..bb5f863 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -94,20 +94,6 @@ impl, T: TaskQueues> Hive { self.grow(num_threads) } - /// Returns the batch limit for worker threads. - pub fn worker_batch_limit(&self) -> usize { - self.shared().worker_batch_limit() - } - - /// Sets the batch limit for worker threads. - /// - /// Depending on this hive's `TaskQueues` implementation, this method may: - /// * have no effect (if it does not support local batching) - /// * block the current thread until all worker thread queues can be resized. - pub fn set_worker_batch_limit(&self, batch_limit: usize) { - self.shared().set_worker_batch_limit(batch_limit); - } - /// Sends one `input` to the `Hive` for procesing and returns the result, blocking until the /// result is available. Creates a channel to send the input and receive the outcome. Returns /// an [`Outcome`] with the task output or an error. 
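To make the blocking `apply` call documented above concrete, here is a minimal usage sketch; it assumes the `EchoWorker` stock bee that this crate's tests use and that `Outcome` implements `Debug`, and it is illustrative rather than part of the patch:

use beekeeper::bee::stock::EchoWorker;
use beekeeper::hive::{Builder, ChannelBuilder};

fn main() {
    // build a hive of four workers backed by the channel-based task queues
    let hive = ChannelBuilder::empty()
        .with_worker_default::<EchoWorker<usize>>()
        .num_threads(4)
        .build();
    // `apply` creates an outcome channel internally and blocks the calling
    // thread until the outcome for this single input is received
    let outcome = hive.apply(42);
    println!("{outcome:?}");
}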
@@ -391,7 +377,7 @@ impl, T: TaskQueues> Hive { /// Returns a read-only reference to the [`Queen`]. pub fn queen(&self) -> &Q { - &self.shared().queen() + self.shared().queen() } /// Returns the number of worker threads that have been requested, i.e., the maximum number of @@ -793,43 +779,30 @@ mod affinity { } } -struct HiveTaskContext<'a, W, Q, T> -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - worker_queues: &'a T::WorkerQueues, - shared: &'a Arc>, - outcome_tx: Option<&'a OutcomeSender>, -} - -impl TaskContext for HiveTaskContext<'_, W, Q, T> -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - fn should_cancel_tasks(&self) -> bool { - self.shared.is_suspended() - } +#[cfg(feature = "batching")] +mod batching { + use crate::bee::{Queen, Worker}; + use crate::hive::{Hive, TaskQueues}; - fn submit_task(&self, input: W::Input) -> TaskId { - let task = self.shared.prepare_task(input, self.outcome_tx); - let task_id = task.id(); - self.worker_queues.push(task); - task_id - } -} + impl Hive + where + W: Worker, + Q: Queen, + T: TaskQueues, + { + /// Returns the batch limit for worker threads. + pub fn worker_batch_limit(&self) -> usize { + self.shared().worker_batch_limit() + } -impl fmt::Debug for HiveTaskContext<'_, W, Q, T> -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HiveTaskContext").finish() + /// Sets the batch limit for worker threads. + /// + /// Depending on this hive's `TaskQueues` implementation, this method may: + /// * have no effect (if it does not support local batching) + /// * block the current thread until all worker thread queues can be resized. + pub fn set_worker_batch_limit(&self, batch_limit: usize) { + self.shared().set_worker_batch_limit(batch_limit); + } } } @@ -947,6 +920,46 @@ mod retry { } } +struct HiveTaskContext<'a, W, Q, T> +where + W: Worker, + Q: Queen, + T: TaskQueues, +{ + worker_queues: &'a T::WorkerQueues, + shared: &'a Arc>, + outcome_tx: Option<&'a OutcomeSender>, +} + +impl TaskContext for HiveTaskContext<'_, W, Q, T> +where + W: Worker, + Q: Queen, + T: TaskQueues, +{ + fn should_cancel_tasks(&self) -> bool { + self.shared.is_suspended() + } + + fn submit_task(&self, input: W::Input) -> TaskId { + let task = self.shared.prepare_task(input, self.outcome_tx); + let task_id = task.id(); + self.worker_queues.push(task); + task_id + } +} + +impl fmt::Debug for HiveTaskContext<'_, W, Q, T> +where + W: Worker, + Q: Queen, + T: TaskQueues, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HiveTaskContext").finish() + } +} + #[cfg(test)] mod tests { use super::Poisoned; diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index ee2a850..808ddb9 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -137,31 +137,6 @@ pub trait Builder: BuilderConfig + Sized { self } - /// Sets the worker thread batch size. - /// - /// This may have no effect if the `batching` feature is disabled, or if the `TaskQueues` - /// implementation used for this hive does not support local batching. - /// - /// If `batch_limit` is `0`, batching is effectively disabled, but note that the performance - /// may be worse than with the `batching` feature disabled. 
- fn batch_limit(mut self, batch_limit: usize) -> Self { - if batch_limit == 0 { - self.config(Token).batch_limit.set(None); - } else { - self.config(Token).batch_limit.set(Some(batch_limit)); - } - self - } - - /// Sets the worker thread batch size to the global default value. - fn with_default_batch_limit(mut self) -> Self { - let _ = self - .config(Token) - .batch_limit - .set(super::config::DEFAULTS.lock().batch_limit.get()); - self - } - /// Sets set list of CPU core indices to which threads in the `Hive` should be pinned. /// /// Core indices are integers in the range `0..N`, where `N` is the number of available CPU @@ -213,6 +188,33 @@ pub trait Builder: BuilderConfig + Sized { self } + /// Sets the worker thread batch size. + /// + /// This may have no effect if the `batching` feature is disabled, or if the `TaskQueues` + /// implementation used for this hive does not support local batching. + /// + /// If `batch_limit` is `0`, batching is effectively disabled, but note that the performance + /// may be worse than with the `batching` feature disabled. + #[cfg(feature = "batching")] + fn batch_limit(mut self, batch_limit: usize) -> Self { + if batch_limit == 0 { + self.config(Token).batch_limit.set(None); + } else { + self.config(Token).batch_limit.set(Some(batch_limit)); + } + self + } + + /// Sets the worker thread batch size to the global default value. + #[cfg(feature = "batching")] + fn with_default_batch_limit(mut self) -> Self { + let _ = self + .config(Token) + .batch_limit + .set(super::config::DEFAULTS.lock().batch_limit.get()); + self + } + /// Sets the maximum number of times to retry a /// [`ApplyError::Retryable`](crate::bee::ApplyError::Retryable) error. A worker /// thread will retry a task until it either returns diff --git a/src/hive/inner/config.rs b/src/hive/inner/config.rs index e41fcbc..17bfb75 100644 --- a/src/hive/inner/config.rs +++ b/src/hive/inner/config.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "batching")] +pub use self::batching::set_batch_limit_default; #[cfg(feature = "retry")] pub use self::retry::{ set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, @@ -8,7 +10,6 @@ use parking_lot::Mutex; use std::sync::LazyLock; const DEFAULT_NUM_THREADS: usize = 4; -const DEFAULT_BATCH_LIMIT: usize = 10; pub static DEFAULTS: LazyLock> = LazyLock::new(|| { let mut config = Config::empty(); @@ -27,10 +28,6 @@ pub fn set_num_threads_default_all() { set_num_threads_default(num_cpus::get()); } -pub fn set_batch_limit_default(batch_limit: usize) { - DEFAULTS.lock().batch_limit.set(Some(batch_limit)); -} - /// Resets all builder defaults to their original values. pub fn reset_defaults() { let mut config = DEFAULTS.lock(); @@ -58,7 +55,8 @@ impl Config { /// Resets config values to their pre-configured defaults. 
fn set_const_defaults(&mut self) { self.num_threads.set(Some(DEFAULT_NUM_THREADS)); - self.batch_limit.set(Some(DEFAULT_BATCH_LIMIT)); + #[cfg(feature = "batching")] + self.set_batch_const_defaults(); #[cfg(feature = "retry")] self.set_retry_const_defaults(); } @@ -145,6 +143,23 @@ mod tests { } } +#[cfg(feature = "batching")] +mod batching { + use super::{Config, DEFAULTS}; + + const DEFAULT_BATCH_LIMIT: usize = 10; + + pub fn set_batch_limit_default(batch_limit: usize) { + DEFAULTS.lock().batch_limit.set(Some(batch_limit)); + } + + impl Config { + pub(super) fn set_batch_const_defaults(&mut self) { + self.batch_limit.set(Some(DEFAULT_BATCH_LIMIT)); + } + } +} + #[cfg(feature = "retry")] mod retry { use super::{Config, DEFAULTS}; diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs index 13f9844..f3f4686 100644 --- a/src/hive/inner/mod.rs +++ b/src/hive/inner/mod.rs @@ -17,7 +17,7 @@ pub mod set_config { } pub use self::builder::{Builder, BuilderConfig}; -pub use self::queue::{ChannelTaskQueues, TaskQueues, WorkerQueues}; +pub use self::queue::{ChannelTaskQueues, TaskQueues, WorkerQueues, WorkstealingTaskQueues}; use self::counter::DualCounter; use self::gate::{Gate, PhasedGate}; diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index af934b6..548d7b8 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -115,6 +115,7 @@ impl GlobalQueue { } } + #[cfg(feature = "batching")] fn try_iter(&self) -> impl Iterator> + '_ { self.global_rx.try_iter() } @@ -142,28 +143,29 @@ pub struct ChannelWorkerQueues { } impl ChannelWorkerQueues { - fn new(thread_index: usize, global_queue: &Arc>, config: &Config) -> Self { + fn new(thread_index: usize, global_queue: &Arc>, _config: &Config) -> Self { Self { _thread_index: thread_index, global: Arc::clone(global_queue), local_abandoned: Default::default(), #[cfg(feature = "batching")] local_batch: RwLock::new(crossbeam_queue::ArrayQueue::new( - config.batch_limit.get_or_default().max(1), + _config.batch_limit.get_or_default().max(1), )), #[cfg(feature = "retry")] - local_retry: super::retry::RetryQueue::new(config.retry_factor.get_or_default()), + local_retry: super::retry::RetryQueue::new(_config.retry_factor.get_or_default()), } } /// Updates the local queues based on the provided `config`: /// If `batching` is enabled, resizes the batch queue if necessary. /// If `retry` is enabled, updates the retry factor. 
- fn update(&self, config: &Config) { + fn update(&self, _config: &Config) { #[cfg(feature = "batching")] - self.update_batch(config); + self.update_batch(_config); #[cfg(feature = "retry")] - self.local_retry.set_delay_factor(config.retry_factor.get_or_default()); + self.local_retry + .set_delay_factor(_config.retry_factor.get_or_default()); } /// Consumes this `ChannelWorkerQueues` and drains the tasks currently in the queues into diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs index 0823a58..69683ec 100644 --- a/src/hive/inner/queue/mod.rs +++ b/src/hive/inner/queue/mod.rs @@ -4,6 +4,7 @@ mod retry; mod workstealing; pub use self::channel::ChannelTaskQueues; +pub use self::workstealing::WorkstealingTaskQueues; use super::{Config, Task, Token}; use crate::bee::Worker; diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index e8dba0f..aeba049 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -1,5 +1,5 @@ use super::{Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; -use crate::atomic::{Atomic, AtomicBool, AtomicUsize}; +use crate::atomic::{Atomic, AtomicBool}; use crate::bee::Worker; use crossbeam_deque::{Injector, Stealer}; use crossbeam_queue::SegQueue; @@ -8,7 +8,7 @@ use rand::prelude::*; use std::ops::Deref; use std::sync::Arc; -struct WorkstealingTaskQueues { +pub struct WorkstealingTaskQueues { global: Arc>, local: RwLock>>>, } @@ -46,7 +46,7 @@ impl TaskQueues for WorkstealingTaskQueues { let local_queue = crossbeam_deque::Worker::new_fifo(); self.global.add_stealer(local_queue.stealer()); let shared = &self.local.read()[thread_index]; - WorkstealingWorkerQueues::new(local_queue, Arc::clone(&shared)) + WorkstealingWorkerQueues::new(local_queue, Arc::clone(shared)) } fn try_push_global(&self, task: Task) -> Result<(), Task> { @@ -193,31 +193,31 @@ struct LocalQueueShared { /// queue of abandon tasks local_abandoned: SegQueue>, #[cfg(feature = "batching")] - batch_limit: AtomicUsize, + batch_limit: crate::atomic::AtomicUsize, /// thread-local queues used for tasks that are waiting to be retried after a failure #[cfg(feature = "retry")] local_retry: super::retry::RetryQueue, } impl LocalQueueShared { - fn new(thread_index: usize, global: &Arc>, config: &Config) -> Self { + fn new(thread_index: usize, global: &Arc>, _config: &Config) -> Self { Self { _thread_index: thread_index, global: Arc::clone(global), local_abandoned: Default::default(), #[cfg(feature = "batching")] - batch_limit: AtomicUsize::new(config.batch_limit.get_or_default()), + batch_limit: crate::atomic::AtomicUsize::new(_config.batch_limit.get_or_default()), #[cfg(feature = "retry")] - local_retry: super::retry::RetryQueue::new(config.retry_factor.get_or_default()), + local_retry: super::retry::RetryQueue::new(_config.retry_factor.get_or_default()), } } - fn update(&self, config: &Config) { + fn update(&self, _config: &Config) { #[cfg(feature = "batching")] - self.batch_limit.set(config.batch_limit.get_or_default()); + self.batch_limit.set(_config.batch_limit.get_or_default()); #[cfg(feature = "retry")] self.local_retry - .set_delay_factor(config.retry_factor.get_or_default()); + .set_delay_factor(_config.retry_factor.get_or_default()); } fn try_pop( diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index c881c83..96119fd 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -359,39 +359,6 @@ impl, T: TaskQueues> Shared { self.no_work_notify_all(); } - /// Returns the 
local queue batch size. - pub fn worker_batch_limit(&self) -> usize { - self.config.batch_limit.get().unwrap_or_default() - } - - /// Changes the local queue batch size. This requires allocating a new queue for each - /// worker thread. - /// - /// Note: this method will block the current thread waiting for all local queues to become - /// writable; if `batch_limit` is less than the current batch size, this method will also - /// block while any thread's queue length is > `batch_limit` before moving the elements. - /// - /// TODO: this needs to be moved to an extension that is specific to channel hive - pub fn set_worker_batch_limit(&self, batch_limit: usize) -> usize { - // update the batch size first so any new threads spawned won't need to have their - // queues resized - let prev_batch_limit = self - .config - .batch_limit - .try_set(batch_limit) - .unwrap_or_default(); - if prev_batch_limit == batch_limit { - return prev_batch_limit; - } - let num_threads = self.num_threads(); - if num_threads == 0 { - return prev_batch_limit; - } - self.task_queues - .update_for_threads(0, num_threads, &self.config); - prev_batch_limit - } - /// Returns a reference to the `Queen`. /// /// Note that, if the queen is a `QueenMut`, the returned value will be a `QueenCell`, and it @@ -611,6 +578,53 @@ mod affinity { } } +#[cfg(feature = "batching")] +mod batching { + use super::Shared; + use crate::bee::{Queen, Worker}; + use crate::hive::TaskQueues; + + impl Shared + where + W: Worker, + Q: Queen, + T: TaskQueues, + { + /// Returns the local queue batch size. + pub fn worker_batch_limit(&self) -> usize { + self.config.batch_limit.get().unwrap_or_default() + } + + /// Changes the local queue batch size. This requires allocating a new queue for each + /// worker thread. + /// + /// Note: this method will block the current thread waiting for all local queues to become + /// writable; if `batch_limit` is less than the current batch size, this method will also + /// block while any thread's queue length is > `batch_limit` before moving the elements. + /// + /// TODO: this needs to be moved to an extension that is specific to channel hive + pub fn set_worker_batch_limit(&self, batch_limit: usize) -> usize { + // update the batch size first so any new threads spawned won't need to have their + // queues resized + let prev_batch_limit = self + .config + .batch_limit + .try_set(batch_limit) + .unwrap_or_default(); + if prev_batch_limit == batch_limit { + return prev_batch_limit; + } + let num_threads = self.num_threads(); + if num_threads == 0 { + return prev_batch_limit; + } + self.task_queues + .update_for_threads(0, num_threads, &self.config); + prev_batch_limit + } + } +} + #[cfg(feature = "retry")] mod retry { use crate::bee::{Queen, TaskId, Worker}; diff --git a/src/hive/inner/task.rs b/src/hive/inner/task.rs index 7123acb..c4ad1ac 100644 --- a/src/hive/inner/task.rs +++ b/src/hive/inner/task.rs @@ -20,60 +20,82 @@ impl Task { } #[cfg(not(feature = "retry"))] -impl Task { - /// Creates a new `Task`. - pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { - Task { - id, - input, - outcome_tx, +mod no_retry { + use super::Task; + use crate::bee::{TaskId, Worker}; + use crate::hive::OutcomeSender; + + impl Task { + /// Creates a new `Task`. 
+ pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { + Task { + id, + input, + outcome_tx, + } + } + + pub fn into_parts(self) -> (TaskId, W::Input, Option>) { + (self.id, self.input, self.outcome_tx) } } - pub fn into_parts(self) -> (TaskId, W::Input, Option>) { - (self.id, self.input, self.outcome_tx) + impl> Clone for Task { + fn clone(&self) -> Self { + Self { + id: self.id.clone(), + input: self.input.clone(), + outcome_tx: self.outcome_tx.clone(), + } + } } } #[cfg(feature = "retry")] -impl Task { - /// Creates a new `Task`. - pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { - Task { - id, - input, - outcome_tx, - attempt: 0, +mod retry { + use super::Task; + use crate::bee::{TaskId, Worker}; + use crate::hive::OutcomeSender; + + impl Task { + /// Creates a new `Task`. + pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { + Task { + id, + input, + outcome_tx, + attempt: 0, + } } - } - /// Creates a new `Task`. - pub fn with_attempt( - id: TaskId, - input: W::Input, - outcome_tx: Option>, - attempt: u32, - ) -> Self { - Task { - id, - input, - outcome_tx, - attempt, + /// Creates a new `Task`. + pub fn with_attempt( + id: TaskId, + input: W::Input, + outcome_tx: Option>, + attempt: u32, + ) -> Self { + Task { + id, + input, + outcome_tx, + attempt, + } } - } - pub fn into_parts(self) -> (TaskId, W::Input, u32, Option>) { - (self.id, self.input, self.attempt, self.outcome_tx) + pub fn into_parts(self) -> (TaskId, W::Input, u32, Option>) { + (self.id, self.input, self.attempt, self.outcome_tx) + } } -} -impl> Clone for Task { - fn clone(&self) -> Self { - Self { - id: self.id.clone(), - input: self.input.clone(), - outcome_tx: self.outcome_tx.clone(), - attempt: self.attempt.clone(), + impl> Clone for Task { + fn clone(&self) -> Self { + Self { + id: self.id, + input: self.input.clone(), + outcome_tx: self.outcome_tx.clone(), + attempt: self.attempt, + } } } } diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 6125e5c..3b0c537 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -367,7 +367,7 @@ mod outcome; pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder}; pub use self::hive::{Hive, Poisoned}; pub use self::husk::Husk; -pub use self::inner::{set_config::*, Builder, ChannelTaskQueues}; +pub use self::inner::{set_config::*, Builder, ChannelTaskQueues, WorkstealingTaskQueues}; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; use self::inner::{Config, Shared, Task, TaskQueues, WorkerQueues}; From 931393f3269a13781ac1bd0cacff05ed708ba015 Mon Sep 17 00:00:00 2001 From: jdidion Date: Wed, 19 Feb 2025 11:43:48 -0800 Subject: [PATCH 14/67] add workstealing builder --- src/hive/builder/workstealing.rs | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/hive/builder/workstealing.rs diff --git a/src/hive/builder/workstealing.rs b/src/hive/builder/workstealing.rs new file mode 100644 index 0000000..e8f7006 --- /dev/null +++ b/src/hive/builder/workstealing.rs @@ -0,0 +1,73 @@ +use super::{BuilderConfig, FullBuilder, Token}; +use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; +use crate::hive::{Config, WorkstealingTaskQueues}; + +#[derive(Clone, Default)] +pub struct WorkstealingBuilder(Config); + +impl WorkstealingBuilder { + /// Creates a new `WorkstealingBuilder` with the given queen and no options configured. 
+ pub fn empty() -> Self { + Self(Config::empty()) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to + /// create [`Worker`]s. + pub fn with_queen(self, queen: I) -> FullBuilder> + where + Q: Queen, + I: Into, + { + FullBuilder::from(self.0, queen.into()) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`Queen`] created with + /// [`Q::default()`](std::default::Default) to create [`Worker`]s. + pub fn with_queen_default(self) -> FullBuilder> + where + Q: Queen + Default, + { + FullBuilder::from(self.0, Q::default()) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`QueenMut`] created with + /// [`Q::default()`](std::default::Default) to create [`Worker`]s. + pub fn with_queen_mut_default( + self, + ) -> FullBuilder, WorkstealingTaskQueues> + where + Q: QueenMut + Default, + { + FullBuilder::from(self.0, QueenCell::new(Q::default())) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created by + /// cloning `worker`. + pub fn with_worker(self, worker: W) -> FullBuilder, WorkstealingTaskQueues> + where + W: Worker + Send + Sync + Clone, + { + FullBuilder::from(self.0, CloneQueen::new(worker)) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created using + /// [`W::default()`](std::default::Default). + pub fn with_worker_default(self) -> FullBuilder, WorkstealingTaskQueues> + where + W: Worker + Send + Sync + Default, + { + FullBuilder::from(self.0, DefaultQueen::default()) + } +} + +impl BuilderConfig for WorkstealingBuilder { + fn config(&mut self, _: Token) -> &mut Config { + &mut self.0 + } +} + +impl From for WorkstealingBuilder { + fn from(value: Config) -> Self { + Self(value) + } +} From f5e7901209958c3ae051755fce059257dda7e0bd Mon Sep 17 00:00:00 2001 From: jdidion Date: Wed, 19 Feb 2025 22:49:22 -0800 Subject: [PATCH 15/67] fix tests --- CHANGELOG.md | 1 + Cargo.toml | 2 +- README.md | 3 +- src/atomic.rs | 2 + src/bee/queen.rs | 4 +- src/hive/builder/mod.rs | 30 ++- src/hive/builder/open.rs | 95 ++++----- src/hive/hive.rs | 106 ++++++---- src/hive/husk.rs | 14 +- src/hive/inner/builder.rs | 53 ++--- src/hive/inner/queue/channel.rs | 280 ++++++++++++++++++--------- src/hive/inner/queue/mod.rs | 14 +- src/hive/inner/queue/retry.rs | 2 +- src/hive/inner/queue/workstealing.rs | 109 ++++++----- src/hive/inner/shared.rs | 9 +- src/hive/mod.rs | 69 ++++--- src/lib.rs | 27 ++- 17 files changed, 503 insertions(+), 317 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae21dc8..2b25839 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The general theme of this release is performance improvement by eliminating thre * `BeeBuilder` and `FullBuilder` are intermediate types that generally should not be instantiated directly. * `beekeeper::bee::Queen::create` now takes `&self` rather than `&mut self`. There is a new type, `beekeeper::bee::QueenMut`, with a `create(&mut self)` method, and needs to be wrapped in a `beekeeper::bee::QueenCell` to implement the `Queen` trait. This enables the `Hive` to create new workers without locking in the case of a `Queen` that does not need mutable state. * `beekeeper::bee::Context` now takes a generic parameter that must be input type of the `Worker`. 
+ * `beekeeper::hive::Hive::try_into_husk` now has an `urgent` parameter to indicate whether queued tasks should be abandoned when shutting down the hive (`true`) or if they should be allowed to finish processing (`false`). * Features * Added the `TaskQueues` trait, which enables `Hive` to be specialized for different implementations of global (i.e., sending tasks from the `Hive` to worker threads) and local (i.e., worker thread-specific) queues. * `ChannelTaskQueues` implements the existing behavior, using a channel for sending tasks. diff --git a/Cargo.toml b/Cargo.toml index 061ea0e..de54cfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ name = "perf" harness = false [features] -default = ["affinity", "batching", "retry"] +default = ["batching", "retry"] affinity = ["dep:core_affinity"] batching = [] retry = [] diff --git a/README.md b/README.md index 0bea8b5..68266e4 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,8 @@ There are multiple methods in each group that differ by how the task results (ca * The methods with the `_unordered` suffix instead return an unordered iterator, which may be more performant than the ordered iterator * The methods with the `_send` suffix accept a channel `Sender` and send the `Outcome`s to that - channel as they are completed + channel as they are completed. + * Note that, for these methods, the `tx` parameter is of type `Borrow>`, which allows you to pass in either a value or a reference. Passing a value causes the `Sender` to be dropped after the call, while passing a reference allows you to use the same `Sender` for multiple `_send` calls. Note that in the later case, you need to explicitly drop the sender (e.g., `drop(tx)`), pass it by value to the last `_send` call, or be careful about how you obtain outcomes from the `Receiver` as methods such as `recv` and `iter` will block until the `Sender` is dropped. You should *not* pass clones of the `Sender` to `_send` methods as this results in slightly worse performance and still has the requirement that you manually drop the original `Sender` value. * The methods with the `_store` suffix store the `Outcome`s in the `Hive`; these may be retrieved later using the [`Hive::take_stored()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.take_stored) method, using one of the `remove*` methods (which requires diff --git a/src/atomic.rs b/src/atomic.rs index 20b0e97..6452717 100644 --- a/src/atomic.rs +++ b/src/atomic.rs @@ -142,6 +142,7 @@ macro_rules! 
atomic_int {
 }
 atomic!(bool);
+atomic_int!(u8);
 atomic_int!(u32);
 atomic_int!(u64);
 atomic_int!(usize);
@@ -458,6 +459,7 @@ mod tests {
 };
 }
 
+ test_numeric_type!(AtomicU8);
 test_numeric_type!(AtomicU32);
 test_numeric_type!(AtomicU64);
 test_numeric_type!(AtomicUsize);
diff --git a/src/bee/queen.rs b/src/bee/queen.rs
index 7642467..8f9d808 100644
--- a/src/bee/queen.rs
+++ b/src/bee/queen.rs
@@ -78,7 +78,7 @@ impl From for QueenCell {
 /// type Output = u8;
 /// type Error = ();
 ///
-/// fn apply(&mut self, input: u8, _: &Context) -> WorkerResult {
+/// fn apply(&mut self, input: u8, _: &Context) -> WorkerResult {
 /// Ok(self.0.saturating_add(input))
 /// }
 /// }
@@ -88,7 +88,7 @@ impl From for QueenCell {
 /// impl Queen for MyQueen {
 /// type Kind = MyWorker;
 ///
-/// fn create(&mut self) -> Self::Kind {
+/// fn create(&self) -> Self::Kind {
 /// MyWorker::default()
 /// }
 /// }
diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs
index 3f6401e..b707dfc 100644
--- a/src/hive/builder/mod.rs
+++ b/src/hive/builder/mod.rs
@@ -1,5 +1,4 @@
-//! There are a few different builder types. All builders implement the `BuilderConfig` trait,
-//! which provides methods to set configuration parameters.
+//! There are a few different builder types.
 //!
 //! * Open: has no type parameters; can only set config parameters. Has methods to create
 //! typed builders.
@@ -13,6 +12,33 @@
 //! Bee /
 //! | /
 //! Full
+//!
+//! All builders implement the `BuilderConfig` trait, which provides methods to set configuration
+//! parameters. The following configuration options are available:
+//! * [`Builder::num_threads`]: number of worker threads that will be spawned by the built `Hive`.
+//! * [`Builder::with_default_num_threads`] will set `num_threads` to the global default value.
+//! * [`Builder::with_thread_per_core`] will set `num_threads` to the number of available CPU
+//! cores.
+//! * [`Builder::thread_name`]: thread name for each of the threads spawned by the built `Hive`. By
+//! default, threads are unnamed.
+//! * [`Builder::thread_stack_size`]: stack size (in bytes) for each of the threads spawned by the
+//! built `Hive`. See the
+//! [`std::thread`](https://doc.rust-lang.org/stable/std/thread/index.html#stack-size)
+//! documentation for details on the default stack size.
+//!
+//! The following configuration options are available when the `retry` feature is enabled:
+//! * [`Builder::max_retries`]: maximum number of times a `Worker` will retry an
+//! [`ApplyError::Retryable`](crate::bee::ApplyError#Retryable) before giving up.
+//! * [`Builder::retry_factor`]: [`Duration`](std::time::Duration) factor for exponential backoff
+//! when retrying an `ApplyError::Retryable` error.
+//! * [`Builder::with_default_retries`] sets the retry options to the global defaults, while
+//! [`Builder::with_no_retries`] disables retrying.
+//!
+//! The following configuration options are available when the `affinity` feature is enabled:
+//! * [`Builder::core_affinity`]: List of CPU core indices to which the threads should be pinned.
+//! * [`Builder::with_default_core_affinity`] will set the list to all CPU core indices, though
+//! only the first `num_threads` indices will be used.
+//!
 mod bee;
 mod channel;
 mod full;
diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs
index b9791b5..fd539b1 100644
--- a/src/hive/builder/open.rs
+++ b/src/hive/builder/open.rs
@@ -4,41 +4,22 @@ use crate::hive::Config;
 /// A builder for a [`Hive`](crate::hive::Hive). 
/// -/// Calling [`Builder::new()`] creates an unconfigured `Builder`, while calling -/// [`Builder::default()`] creates a `Builder` with fields preset to the global default values. +/// Calling [`OpenBuilder::empty()`] creates an unconfigured `Builder`, while calling +/// [`OpenBuilder::default()`] creates a `Builder` with fields preset to the global default values. /// Global defaults can be changed using the /// [`beekeeper::hive::set_*_default`](crate::hive#functions) functions. /// -/// The configuration options available: -/// * [`Builder::num_threads`]: number of worker threads that will be spawned by the built `Hive`. -/// * [`Builder::with_default_num_threads`] will set `num_threads` to the global default value. -/// * [`Builder::with_thread_per_core`] will set `num_threads` to the number of available CPU -/// cores. -/// * [`Builder::thread_name`]: thread name for each of the threads spawned by the built `Hive`. By -/// default, threads are unnamed. -/// * [`Builder::thread_stack_size`]: stack size (in bytes) for each of the threads spawned by the -/// built `Hive`. See the -/// [`std::thread`](https://doc.rust-lang.org/stable/std/thread/index.html#stack-size) -/// documentation for details on the default stack size. +/// See the [module documentation](crate::hive::builder) for details on the available configuration +/// options. /// -/// The following configuration options are available when the `retry` feature is enabled: -/// * [`Builder::max_retries`]: maximum number of times a `Worker` will retry an -/// [`ApplyError::Retryable`](crate::bee::ApplyError#Retryable) before giving up. -/// * [`Builder::retry_factor`]: [`Duration`](std::time::Duration) factor for exponential backoff -/// when retrying an `ApplyError::Retryable` error. -/// * [`Builder::with_default_retries`] sets the retry options to the global defaults, while -/// [`Builder::with_no_retries`] disabled retrying. +/// This builder needs to be specialized to both the `Queen` and `TaskQueues` types. You can do +/// this in either order. /// -/// The following configuration options are available when the `affinity` feature is enabled: -/// * [`Builder::core_affinity`]: List of CPU core indices to which the threads should be pinned. -/// * [`Builder::with_default_core_affinity`] will set the list to all CPU core indices, though -/// only the first `num_threads` indices will be used. -/// -/// To create the [`Hive`], call one of the `build*` methods: -/// * [`Builder::build`] requires a [`Queen`] instance. -/// * [`Builder::build_default`] requires a [`Queen`] type that implements [`Default`]. -/// * [`Builder::build_with`] requires a [`Worker`] instance that implements [`Clone`]. -/// * [`Builder::build_with_default`] requires a [`Worker`] type that implements [`Default`]. +/// * Calling one of the `with_queen*` methods returns a `BeeBuilder` specialized to a `Queen`. +/// * Calling `with_worker` or `with_worker_default` returns a `BeeBuilder` specialized to a +/// `CloneQueen` or `DefaultQueen` (respectively) for a specific `Worker` type. +/// * Calling `with_channel_queues` or `with_workstealing_queues` returns a `ChannelBuilder` or +/// `WorkstealingBuilder` specialized to a `TaskQueues` type. 
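+///
+/// For example, the following two hives are built by specializing in opposite orders. This is a
+/// sketch: it assumes the stock `EchoWorker` and that the intermediate `BeeBuilder` also offers
+/// `with_channel_queues`, per the "either order" note above.
+///
+/// ```
+/// use beekeeper::hive::OpenBuilder;
+/// type MyWorker = beekeeper::bee::stock::EchoWorker<usize>;
+///
+/// // specialize the `Queen` first, then the `TaskQueues`
+/// let hive1 = OpenBuilder::empty()
+///     .with_worker_default::<MyWorker>()
+///     .with_channel_queues()
+///     .build();
+///
+/// // specialize the `TaskQueues` first, then the `Queen`
+/// let hive2 = OpenBuilder::empty()
+///     .with_channel_queues()
+///     .with_worker_default::<MyWorker>()
+///     .build();
+/// ```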
/// /// # Examples /// @@ -46,9 +27,10 @@ use crate::hive::Config; /// a 8 MB stack size: /// /// ``` +/// # use beekeeper::hive::{Builder, OpenBuilder}; /// type MyWorker = beekeeper::bee::stock::ThunkWorker<()>; /// -/// let hive = beekeeper::hive::Builder::empty() +/// let hive = OpenBuilder::empty() /// .num_threads(8) /// .thread_stack_size(8_000_000) /// .with_worker_default::() @@ -70,8 +52,8 @@ impl OpenBuilder { /// # Examples /// /// ``` - /// # use beekeeper::hive::{Builder, Hive}; - /// # use beekeeper::bee::{Context, Queen, Worker, WorkerResult}; + /// # use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// # use beekeeper::bee::{Context, QueenMut, Worker, WorkerResult}; /// /// #[derive(Debug)] /// struct CounterWorker { @@ -95,7 +77,7 @@ impl OpenBuilder { /// type Output = String; /// type Error = (); /// - /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { /// self.input_count += 1; /// self.input_sum += input; /// let s = format!( @@ -111,7 +93,7 @@ impl OpenBuilder { /// num_workers: usize /// } /// - /// impl Queen for CounterQueen { + /// impl QueenMut for CounterQueen { /// type Kind = CounterWorker; /// /// fn create(&mut self) -> Self::Kind { @@ -121,16 +103,17 @@ impl OpenBuilder { /// } /// /// # fn main() { - /// let hive = Builder::new() + /// let hive = ChannelBuilder::empty() /// .num_threads(8) /// .thread_stack_size(4_000_000) - /// .build(CounterQueen::default()); + /// .with_queen_mut_default::() + /// .build(); /// /// for i in 0..100 { /// hive.apply_store(i); /// } - /// let husk = hive.try_into_husk().unwrap(); - /// assert_eq!(husk.queen().num_workers, 8); + /// let husk = hive.try_into_husk(false).unwrap(); + /// assert_eq!(husk.queen().get().num_workers, 8); /// # } /// ``` pub fn with_queen>(self, queen: I) -> BeeBuilder { @@ -155,7 +138,7 @@ impl OpenBuilder { /// # Examples /// /// ``` - /// # use beekeeper::hive::{Builder, OutcomeIteratorExt}; + /// # use beekeeper::hive::{Builder, ChannelBuilder, OutcomeIteratorExt}; /// # use beekeeper::bee::{Context, Worker, WorkerResult}; /// /// #[derive(Debug, Clone)] @@ -173,13 +156,13 @@ impl OpenBuilder { /// type Output = isize; /// type Error = (); /// - /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { /// let (operand, operator) = input; /// let value = match operator % 4 { - /// 0 => operand + self.config(Token), - /// 1 => operand - self.config(Token), - /// 2 => operand * self.config(Token), - /// 3 => operand / self.config(Token), + /// 0 => operand + self.0, + /// 1 => operand - self.0, + /// 2 => operand * self.0, + /// 3 => operand / self.0, /// _ => unreachable!(), /// }; /// Ok(value) @@ -187,10 +170,11 @@ impl OpenBuilder { /// } /// /// # fn main() { - /// let hive = Builder::new() + /// let hive = ChannelBuilder::empty() /// .num_threads(8) /// .thread_stack_size(4_000_000) - /// .build_with(MathWorker(5isize)); + /// .with_worker(MathWorker(5isize)) + /// .build(); /// /// let sum: isize = hive /// .map((0..100).zip((0..4).cycle())) @@ -212,7 +196,7 @@ impl OpenBuilder { /// # Examples /// /// ``` - /// # use beekeeper::hive::{Builder, OutcomeIteratorExt}; + /// # use beekeeper::hive::{Builder, ChannelBuilder, OutcomeIteratorExt}; /// # use beekeeper::bee::{Context, Worker, WorkerResult}; /// # use std::num::NonZeroIsize; /// @@ -224,13 +208,13 @@ impl OpenBuilder { /// type Output 
= isize; /// type Error = (); /// - /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { + /// fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { /// let (operand, operator) = input; /// let result = match operator % 4 { - /// 0 => self.config(Token) + operand.get(), - /// 1 => self.config(Token) - operand.get(), - /// 2 => self.config(Token) * operand.get(), - /// 3 => self.config(Token) / operand.get(), + /// 0 => self.0 + operand.get(), + /// 1 => self.0 - operand.get(), + /// 2 => self.0 * operand.get(), + /// 3 => self.0 / operand.get(), /// _ => unreachable!(), /// }; /// Ok(result) @@ -238,10 +222,11 @@ impl OpenBuilder { /// } /// /// # fn main() { - /// let hive = Builder::new() + /// let hive = ChannelBuilder::empty() /// .num_threads(8) /// .thread_stack_size(4_000_000) - /// .build_with_default::(); + /// .with_worker_default::() + /// .build(); /// /// let sum: isize = hive /// .map((1..=100).map(|i| NonZeroIsize::new(i).unwrap()).zip((0..4).cycle())) diff --git a/src/hive/hive.rs b/src/hive/hive.rs index bb5f863..cf22a60 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -3,6 +3,7 @@ use super::{ OutcomeIteratorExt, OutcomeSender, Shared, SpawnError, TaskQueues, WorkerQueues, }; use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; +use std::borrow::Borrow; use std::collections::HashMap; use std::fmt; use std::ops::{Deref, DerefMut}; @@ -100,13 +101,14 @@ impl, T: TaskQueues> Hive { pub fn apply(&self, input: W::Input) -> Outcome { let (tx, rx) = super::outcome_channel(); let task_id = self.shared().send_one_global(input, Some(&tx)); + drop(tx); rx.recv().unwrap_or_else(|_| Outcome::Missing { task_id }) } /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of /// the task will be sent to `tx` upon completion. - pub fn apply_send(&self, input: W::Input, tx: &OutcomeSender) -> TaskId { - self.shared().send_one_global(input, Some(tx)) + pub fn apply_send>>(&self, input: W::Input, tx: S) -> TaskId { + self.shared().send_one_global(input, Some(tx.borrow())) } /// Sends one `input` to the `Hive` for processing and returns its ID immediately. The @@ -127,6 +129,7 @@ impl, T: TaskQueues> Hive { { let (tx, rx) = super::outcome_channel(); let task_ids = self.shared().send_batch_global(batch, Some(&tx)); + drop(tx); rx.select_ordered(task_ids) } @@ -152,12 +155,14 @@ impl, T: TaskQueues> Hive { /// /// This method is more efficient than [`map_send`](Self::map_send) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm_send(&self, batch: I, outcome_tx: &OutcomeSender) -> Vec + pub fn swarm_send(&self, batch: I, outcome_tx: S) -> Vec where + S: Borrow>, I: IntoIterator, I::IntoIter: ExactSizeIterator, { - self.shared().send_batch_global(batch, Some(outcome_tx)) + self.shared() + .send_batch_global(batch, Some(outcome_tx.borrow())) } /// Sends a `batch` of inputs to the `Hive` for processing, and returns a [`Vec`] of task IDs. @@ -185,6 +190,7 @@ impl, T: TaskQueues> Hive { .into_iter() .map(|task| self.apply_send(task, &tx)) .collect(); + drop(tx); rx.select_ordered(task_ids) } @@ -203,6 +209,7 @@ impl, T: TaskQueues> Hive { .into_iter() .map(|task| self.apply_send(task, &tx)) .collect(); + drop(tx); rx.select_unordered(task_ids) } @@ -211,14 +218,14 @@ impl, T: TaskQueues> Hive { /// /// [`swarm_send`](Self::swarm_send) should be preferred when `inputs` is an /// [`ExactSizeIterator`]. 
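+ ///
+ /// A minimal sketch of the send-and-collect pattern (assuming the stock `EchoWorker`): pass
+ /// `&tx` so the sender can be reused, then drop it so that iterating the receiver terminates.
+ ///
+ /// ```
+ /// # use beekeeper::hive::{outcome_channel, DefaultHive};
+ /// # type MyWorker = beekeeper::bee::stock::EchoWorker<usize>;
+ /// let hive = DefaultHive::<MyWorker>::default();
+ /// let (tx, rx) = outcome_channel();
+ /// // pass a reference so the same sender could be reused for another `_send` call
+ /// let task_ids = hive.map_send(0..10usize, &tx);
+ /// // explicitly drop the sender so that `rx.iter()` does not block forever
+ /// drop(tx);
+ /// assert_eq!(rx.iter().count(), 10);
+ /// ```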
- pub fn map_send( + pub fn map_send>>( &self, inputs: impl IntoIterator, - tx: &OutcomeSender, + tx: S, ) -> Vec { inputs .into_iter() - .map(|input| self.apply_send(input, tx)) + .map(|input| self.apply_send(input, tx.borrow())) .collect() } @@ -237,7 +244,7 @@ impl, T: TaskQueues> Hive { /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing. /// Returns an [`OutcomeBatch`] of the outputs and the final state value. - pub fn scan( + pub fn scan( &self, items: impl IntoIterator, init: St, @@ -248,6 +255,7 @@ impl, T: TaskQueues> Hive { { let (tx, rx) = super::outcome_channel(); let (task_ids, fold_value) = self.scan_send(items, &tx, init, f); + drop(tx); let outcomes = rx.select_unordered(task_ids).into(); (outcomes, fold_value) } @@ -256,7 +264,7 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing, /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final /// state value. - pub fn try_scan( + pub fn try_scan( &self, items: impl IntoIterator, init: St, @@ -276,6 +284,7 @@ impl, T: TaskQueues> Hive { (task_ids, errors, acc) }, ); + drop(tx); let outcomes = rx.select_unordered(task_ids).into(); (outcomes, errors, fold_value) } @@ -284,21 +293,22 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the /// task IDs and the final state value. - pub fn scan_send( + pub fn scan_send( &self, items: impl IntoIterator, - tx: &OutcomeSender, + tx: S, init: St, mut f: F, ) -> (Vec, St) where + S: Borrow>, F: FnMut(&mut St, I) -> W::Input, { items .into_iter() .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { let input = f(&mut acc, item); - task_ids.push(self.apply_send(input, tx)); + task_ids.push(self.apply_send(input, tx.borrow())); (task_ids, acc) }) } @@ -308,20 +318,21 @@ impl, T: TaskQueues> Hive { /// or an error. The outputs are sent to `tx` in the order they become available. This /// function returns the final state value and a [`Vec`] of results, where each result is /// either a task ID or an error. - pub fn try_scan_send( + pub fn try_scan_send( &self, items: impl IntoIterator, - tx: &OutcomeSender, + tx: S, init: St, mut f: F, ) -> (Vec>, St) where + S: Borrow>, F: FnMut(&mut St, I) -> Result, { items .into_iter() .fold((Vec::new(), init), |(mut results, mut acc), inp| { - results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx))); + results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx.borrow()))); (results, acc) }) } @@ -330,7 +341,7 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of /// the tasks are retained and available for later retrieval. - pub fn scan_store( + pub fn scan_store( &self, items: impl IntoIterator, init: St, @@ -353,7 +364,7 @@ impl, T: TaskQueues> Hive { /// or an error. This function returns the final value of the state value and a [`Vec`] of /// results, where each result is either a task ID or an error. The [`Outcome`]s of the /// tasks are retained and available for later retrieval. 
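+ ///
+ /// A minimal sketch (assuming the stock `EchoWorker`, and the same result/state ordering as
+ /// [`try_scan_send`](Self::try_scan_send)): send a running sum, rejecting items once the sum
+ /// exceeds a limit.
+ ///
+ /// ```
+ /// # use beekeeper::hive::DefaultHive;
+ /// # type MyWorker = beekeeper::bee::stock::EchoWorker<usize>;
+ /// let hive = DefaultHive::<MyWorker>::default();
+ /// let (results, sum) = hive.try_scan_store(0..10usize, 0usize, |acc, i| {
+ ///     *acc += i;
+ ///     if *acc > 100 { Err("limit exceeded") } else { Ok(*acc) }
+ /// });
+ /// assert_eq!(sum, 45);
+ /// // no item pushed the sum past the limit, so every result holds a task ID
+ /// assert!(results.iter().all(|r| r.is_ok()));
+ /// ```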
- pub fn try_scan_store(
+ pub fn try_scan_store(
 &self,
 items: impl IntoIterator,
 init: St,
@@ -451,14 +462,15 @@ impl, T: TaskQueues> Hive {
 ///
 /// ```
 /// use beekeeper::bee::stock::{Thunk, ThunkWorker};
- /// use beekeeper::hive::Builder;
+ /// use beekeeper::hive::{Builder, ChannelBuilder};
 /// use std::thread;
 /// use std::time::Duration;
 ///
 /// # fn main() {
- /// let hive = Builder::new()
+ /// let hive = ChannelBuilder::empty()
 /// .num_threads(4)
- /// .build_with_default::>();
+ /// .with_worker_default::>()
+ /// .build();
 /// hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3)))));
 /// thread::sleep(Duration::from_secs(1)); // Allow first set of tasks to be started.
 /// // There should be 4 active tasks and 6 queued tasks.
@@ -519,14 +531,22 @@ impl, T: TaskQueues> Hive {
 self.shared().take_outcomes()
 }
 
- fn try_close(mut self) -> Option> {
+ /// Consumes this `Hive` and attempts to acquire the shared data object.
+ ///
+ /// This closes the task queues so that no more tasks may be submitted. If `urgent` is `true`,
+ /// worker threads are also prevented from taking any more tasks from the queues; otherwise,
+ /// this method blocks while all queued tasks are processed.
+ ///
+ /// If this `Hive` has been cloned, and those clones have not been dropped, this method returns
+ /// `None`.
+ fn try_close(mut self, urgent: bool) -> Option> {
 if self.shared().num_referrers() > 1 {
 return None;
 }
 // take the inner value and replace it with `None`
 let shared = self.0.take().unwrap();
 // close the global queue to prevent new tasks from being submitted
- shared.close_task_queues();
+ shared.close_task_queues(urgent);
 // wait for all tasks to finish
 shared.wait_on_done();
 // unwrap the Arc and return the inner Shared value
@@ -535,24 +555,33 @@ impl, T: TaskQueues> Hive {
 /// Consumes this `Hive` and attempts to shut it down gracefully.
 ///
- /// All unprocessed tasks and stored outcomes are discarded.
- ///
 /// If this `Hive` has been cloned, and those clones have not been dropped, this method returns
 /// `false`.
 ///
+ /// This closes the task queues so that no more tasks may be submitted. If `urgent` is `true`,
+ /// worker threads are also prevented from taking any more tasks from the queues, and all
+ /// queued tasks are converted to `Unprocessed` outcomes and sent or discarded; otherwise,
+ /// this method blocks while all queued tasks are processed.
+ ///
 /// Note that it is not necessary to call this method explicitly - all resources are dropped
 /// automatically when the last clone of the hive is dropped.
- pub fn close(self) -> bool {
- self.try_close().is_some()
+ pub fn close(self, urgent: bool) -> bool {
+ self.try_close(urgent).is_some()
 }
 
- /// Consumes this `Hive` and attempts to convert any remaining unprocessed tasks into
- /// `Unprocessed` outcomes and either sends each to its outcome channel or adds it to the
- /// stored outcomes.
+ /// Consumes this `Hive` and returns a map of stored outcomes.
 ///
- /// Returns a map of stored outcomes.
- pub fn try_into_outcomes(self) -> Option>> {
- self.try_close().map(|shared| shared.into_outcomes())
+ /// If this `Hive` has been cloned, and those clones have not been dropped, this method
+ /// returns `None` since it cannot take exclusive ownership of the internal shared data.
+ ///
+ /// This closes the task queues so that no more tasks may be submitted. 
If `urgent` is `true`, + /// worker threads are also prevented from taking any more tasks from the queues, and all + /// queued tasks are converted to `Unprocessed` outcomes and sent or stored; otherwise, + /// this method blocks while all queued tasks are processed. + /// + /// This method first joins on the `Hive` to wait for all tasks to finish. + pub fn try_into_outcomes(self, urgent: bool) -> Option>> { + self.try_close(urgent).map(|shared| shared.into_outcomes()) } /// Consumes this `Hive` and attempts to return a [`Husk`] containing the remnants of this @@ -562,13 +591,20 @@ impl, T: TaskQueues> Hive { /// If this `Hive` has been cloned, and those clones have not been dropped, this method /// returns `None` since it cannot take exclusive ownership of the internal shared data. /// + /// This closes the task queues so that no more tasks may be submitted. If `urgent` is `true`, + /// worker threads are also prevented from taking any more tasks from the queues, and all + /// queued tasks are converted to `Unprocessed` outcomes and sent or stored; otherwise, + /// this method blocks while all queued tasks are processed. + /// /// This method first joins on the `Hive` to wait for all tasks to finish. - pub fn try_into_husk(self) -> Option> { - self.try_close().map(|shared| shared.into_husk()) + pub fn try_into_husk(self, urgent: bool) -> Option> { + self.try_close(urgent).map(|shared| shared.into_husk()) } } -impl Default for Hive, ChannelTaskQueues> { +pub type DefaultHive = Hive, ChannelTaskQueues>; + +impl Default for DefaultHive { fn default() -> Self { ChannelBuilder::default().with_worker_default().build() } diff --git a/src/hive/husk.rs b/src/hive/husk.rs index e6ce12c..95b18cd 100644 --- a/src/hive/husk.rs +++ b/src/hive/husk.rs @@ -152,7 +152,7 @@ mod tests { let mut task_ids = hive.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive.suspend(); - let mut husk = hive.try_into_husk().unwrap(); + let mut husk = hive.try_into_husk(false).unwrap(); assert!(husk.has_unprocessed()); for i in task_ids.iter() { assert!(husk.get(*i).unwrap().is_unprocessed()); @@ -178,12 +178,12 @@ mod tests { let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive1.suspend(); - let husk1 = hive1.try_into_husk().unwrap(); + let husk1 = hive1.try_into_husk(false).unwrap(); let (hive2, _) = husk1.into_hive_swarm_store_unprocessed::>(); // now spin up worker threads to process the tasks hive2.grow(8).expect("error spawning threads"); hive2.join(); - let husk2 = hive2.try_into_husk().unwrap(); + let husk2 = hive2.try_into_husk(false).unwrap(); assert!(!husk2.has_unprocessed()); assert!(husk2.has_successes()); assert_eq!(husk2.iter_successes().count(), 10); @@ -199,13 +199,13 @@ mod tests { let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i))); // cancel and smash the hive before the tasks can be processed hive1.suspend(); - let husk1 = hive1.try_into_husk().unwrap(); + let husk1 = hive1.try_into_husk(false).unwrap(); let (tx, rx) = outcome_channel(); let (hive2, task_ids) = husk1.into_hive_swarm_send_unprocessed::>(&tx); // now spin up worker threads to process the tasks hive2.grow(8).expect("error spawning threads"); hive2.join(); - let husk2 = hive2.try_into_husk().unwrap(); + let husk2 = hive2.try_into_husk(false).unwrap(); assert!(husk2.is_empty()); let mut outputs = rx .select_ordered(task_ids) @@ -223,7 +223,7 @@ mod tests { .build(); 
hive.map_store((0..10).map(|i| Thunk::of(move || i))); hive.join(); - let mut outputs = hive.try_into_husk().unwrap().into_parts().1.unwrap(); + let mut outputs = hive.try_into_husk(false).unwrap().into_parts().1.unwrap(); outputs.sort(); assert_eq!(outputs, (0..10).collect::>()); } @@ -239,7 +239,7 @@ mod tests { (0..10).map(|i| Thunk::of(move || if i == 5 { panic!("oh no!") } else { i })), ); hive.join(); - let (_, result) = hive.try_into_husk().unwrap().into_parts(); + let (_, result) = hive.try_into_husk(false).unwrap().into_parts(); let _ = result.ok_or_unwrap_errors(true); } } diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 808ddb9..c10e102 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -20,12 +20,13 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; + /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; /// /// # fn main() { - /// let hive = Builder::new() + /// let hive = ChannelBuilder::empty() /// .num_threads(8) - /// .build_with_default::>(); + /// .with_worker_default::>() + /// .build(); /// /// for _ in 0..100 { /// hive.apply_store(Thunk::of(|| { @@ -56,12 +57,13 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; + /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; /// /// # fn main() { - /// let hive = Builder::new() + /// let hive = ChannelBuilder::empty() /// .with_thread_per_core() - /// .build_with_default::>(); + /// .with_worker_default::>() + /// .build(); /// /// for _ in 0..100 { /// hive.apply_store(Thunk::of(|| { @@ -84,13 +86,14 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; + /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; /// use std::thread; /// /// # fn main() { - /// let hive = Builder::default() + /// let hive = ChannelBuilder::default() /// .thread_name("foo") - /// .build_with_default::>(); + /// .with_worker_default::>() + /// .build(); /// /// for _ in 0..100 { /// hive.apply_store(Thunk::of(|| { @@ -117,12 +120,13 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; + /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; /// /// # fn main() { - /// let hive = Builder::default() + /// let hive = ChannelBuilder::default() /// .thread_stack_size(4_000_000) - /// .build_with_default::>(); + /// .with_worker_default::>() + /// .build(); /// /// for _ in 0..100 { /// hive.apply_store(Thunk::of(|| { @@ -154,13 +158,14 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, Hive}; + /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; /// /// # fn main() { - /// let hive = Builder::new() + /// let hive = ChannelBuilder::empty() /// .num_threads(4) /// .core_affinity(0..4) - /// .build_with_default::>(); + /// .with_worker_default::>() + /// .build(); /// /// for _ in 0..100 { /// hive.apply_store(Thunk::of(|| { @@ -229,12 +234,12 @@ pub trait Builder: BuilderConfig + Sized { /// ``` /// use beekeeper::bee::{ApplyError, Context}; /// use beekeeper::bee::stock::RetryCaller; - /// use beekeeper::hive::{Builder, Hive}; + /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; 
/// use std::time; /// /// fn sometimes_fail( /// i: usize, - /// _: &Context + /// _: &Context /// ) -> Result> { /// match i % 3 { /// 0 => Ok("Success".into()), @@ -245,9 +250,10 @@ pub trait Builder: BuilderConfig + Sized { /// } /// /// # fn main() { - /// let hive = Builder::default() + /// let hive = ChannelBuilder::default() /// .max_retries(3) - /// .build_with(RetryCaller::of(sometimes_fail)); + /// .with_worker(RetryCaller::of(sometimes_fail)) + /// .build(); /// /// for i in 0..10 { /// hive.apply_store(i); @@ -275,10 +281,10 @@ pub trait Builder: BuilderConfig + Sized { /// ``` /// use beekeeper::bee::{ApplyError, Context}; /// use beekeeper::bee::stock::RetryCaller; - /// use beekeeper::hive::{Builder, Hive}; + /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; /// use std::time; /// - /// fn echo_time(i: usize, ctx: &Context) -> Result> { + /// fn echo_time(i: usize, ctx: &Context) -> Result> { /// let attempt = ctx.attempt(); /// if attempt == 3 { /// Ok("Success".into()) @@ -290,10 +296,11 @@ pub trait Builder: BuilderConfig + Sized { /// } /// /// # fn main() { - /// let hive = Builder::default() + /// let hive = ChannelBuilder::default() /// .max_retries(3) /// .retry_factor(time::Duration::from_secs(1)) - /// .build_with(RetryCaller::of(echo_time)); + /// .with_worker(RetryCaller::of(echo_time)) + /// .build(); /// /// for i in 0..10 { /// hive.apply_store(i); diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index 548d7b8..1a72689 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -1,8 +1,7 @@ //! Implementation of `TaskQueues` that uses `crossbeam` channels for the global queue (i.e., for //! sending tasks from the `Hive` to the worker threads) and a default implementation of local //! queues that depends on which combination of the `retry` and `batching` features are enabled. 
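+//!
+//! A sketch of how a hive selects this implementation through the builder API (assuming the
+//! stock `EchoWorker`):
+//!
+//! ```
+//! use beekeeper::hive::{Builder, ChannelBuilder};
+//! type MyWorker = beekeeper::bee::stock::EchoWorker<usize>;
+//! let hive = ChannelBuilder::default()
+//!     .num_threads(4)
+//!     .with_worker_default::<MyWorker>()
+//!     .build();
+//! // tasks submitted here flow through the channel-based global queue
+//! hive.apply_store(42);
+//! ```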
-use super::{Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; -use crate::atomic::{Atomic, AtomicBool}; +use super::{Closed, Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; use crate::bee::Worker; use crossbeam_channel::RecvTimeoutError; use crossbeam_queue::SegQueue; @@ -20,12 +19,11 @@ type TaskReceiver = crossbeam_channel::Receiver>; pub struct ChannelTaskQueues { global: Arc>, - local: RwLock>>>, + local: RwLock>>>, } impl TaskQueues for ChannelTaskQueues { type WorkerQueues = ChannelWorkerQueues; - type WorkerQueuesTarget = Arc; fn new(_: Token) -> Self { Self { @@ -38,36 +36,34 @@ impl TaskQueues for ChannelTaskQueues { let mut local_queues = self.local.write(); assert_eq!(local_queues.len(), start_index); (start_index..end_index).for_each(|thread_index| { - local_queues.push(Arc::new(ChannelWorkerQueues::new( - thread_index, - &self.global, - config, - ))) + local_queues.push(Arc::new(LocalQueueShared::new(thread_index, config))) }); } fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { let local_queues = self.local.write(); - assert!(local_queues.len() > end_index); + assert!(local_queues.len() >= end_index); local_queues[start_index..end_index] .iter() - .for_each(|queue| queue.update(config)); + .for_each(|queue| queue.update(&self.global, config)); } - fn worker_queues(&self, thread_index: usize) -> Arc { - Arc::clone(&self.local.read()[thread_index]) + fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueues { + ChannelWorkerQueues::new(&self.global, &self.local.read()[thread_index]) } fn try_push_global(&self, task: Task) -> Result<(), Task> { self.global.try_push(task) } - fn close(&self, _: Token) { - self.global.close() + fn close(&self, urgent: bool, _: Token) { + self.global.close(urgent) } fn drain(self) -> Vec> { - self.close(Token); + if !self.global.is_closed() { + panic!("close must be called before drain"); + } let mut tasks = Vec::new(); let global = crate::hive::unwrap_arc(self.global); global.drain_into(&mut tasks); @@ -82,7 +78,7 @@ impl TaskQueues for ChannelTaskQueues { pub struct GlobalQueue { global_tx: TaskSender, global_rx: TaskReceiver, - closed: AtomicBool, + closed: Closed, } impl GlobalQueue { @@ -97,7 +93,7 @@ impl GlobalQueue { #[inline] fn try_push(&self, task: Task) -> Result<(), Task> { - if self.closed.get() { + if !self.closed.can_push() { return Err(task); } self.global_tx.send(task).map_err(|err| err.into_inner()) @@ -108,7 +104,7 @@ impl GlobalQueue { match self.global_rx.recv_timeout(RECV_TIMEOUT) { Ok(task) => Ok(task), Err(RecvTimeoutError::Disconnected) => Err(PopTaskError::Closed), - Err(RecvTimeoutError::Timeout) if self.closed.get() && self.global_rx.is_empty() => { + Err(RecvTimeoutError::Timeout) if self.is_closed() && self.global_rx.is_empty() => { Err(PopTaskError::Closed) } Err(RecvTimeoutError::Timeout) => Err(PopTaskError::Empty), @@ -120,8 +116,13 @@ impl GlobalQueue { self.global_rx.try_iter() } - fn close(&self) { - self.closed.set(true); + #[inline] + fn is_closed(&self) -> bool { + self.closed.is_closed() + } + + fn close(&self, urgent: bool) { + self.closed.set(urgent); } fn drain_into(self, tasks: &mut Vec>) { @@ -130,78 +131,94 @@ impl GlobalQueue { } pub struct ChannelWorkerQueues { - _thread_index: usize, global: Arc>, + shared: Arc>, +} + +impl ChannelWorkerQueues { + fn new(global_queue: &Arc>, shared: &Arc>) -> Self { + Self { + global: Arc::clone(global_queue), + shared: Arc::clone(shared), + } + } +} + +impl WorkerQueues for ChannelWorkerQueues 
{ + fn push(&self, task: Task) { + self.shared.push(task, &self.global); + } + + fn try_pop(&self) -> Result, PopTaskError> { + self.shared.try_pop(&self.global) + } + + #[cfg(feature = "retry")] + fn try_push_retry(&self, task: Task) -> Result> { + self.shared.try_push_retry(task) + } +} + +struct LocalQueueShared { + _thread_index: usize, /// queue of abandon tasks local_abandoned: SegQueue>, /// thread-local queue of tasks used when the `batching` feature is enabled #[cfg(feature = "batching")] - local_batch: RwLock>>, + local_batch: batching::WorkerBatchQueue, /// thread-local queues used for tasks that are waiting to be retried after a failure #[cfg(feature = "retry")] - local_retry: super::retry::RetryQueue, + local_retry: super::RetryQueue, } -impl ChannelWorkerQueues { - fn new(thread_index: usize, global_queue: &Arc>, _config: &Config) -> Self { +impl LocalQueueShared { + fn new(thread_index: usize, _config: &Config) -> Self { Self { _thread_index: thread_index, - global: Arc::clone(global_queue), local_abandoned: Default::default(), #[cfg(feature = "batching")] - local_batch: RwLock::new(crossbeam_queue::ArrayQueue::new( - _config.batch_limit.get_or_default().max(1), - )), + local_batch: batching::WorkerBatchQueue::new(_config.batch_limit.get_or_default()), #[cfg(feature = "retry")] - local_retry: super::retry::RetryQueue::new(_config.retry_factor.get_or_default()), + local_retry: super::RetryQueue::new(_config.retry_factor.get_or_default()), } } /// Updates the local queues based on the provided `config`: /// If `batching` is enabled, resizes the batch queue if necessary. /// If `retry` is enabled, updates the retry factor. - fn update(&self, _config: &Config) { + fn update(&self, _global: &GlobalQueue, _config: &Config) { #[cfg(feature = "batching")] - self.update_batch(_config); + self.local_batch + .set_limit(_config.batch_limit.get_or_default(), _global, self); #[cfg(feature = "retry")] self.local_retry .set_delay_factor(_config.retry_factor.get_or_default()); } - /// Consumes this `ChannelWorkerQueues` and drains the tasks currently in the queues into - /// `tasks`. 
- fn drain_into(self, tasks: &mut Vec>) {
- while let Some(task) = self.local_abandoned.pop() {
- tasks.push(task);
- }
- #[cfg(feature = "batching")]
- {
- let batch = self.local_batch.into_inner();
- tasks.reserve(batch.len());
- while let Some(task) = batch.pop() {
- tasks.push(task);
- }
- }
- #[cfg(feature = "retry")]
- self.local_retry.drain_into(tasks);
- }
-}
-
-impl WorkerQueues for ChannelWorkerQueues {
- fn push(&self, task: Task) {
+ #[inline]
+ fn push(&self, task: Task, global: &GlobalQueue) {
 #[cfg(feature = "batching")]
- let task = match self.local_batch.read().push(task) {
+ let task = match self.local_batch.try_push(task) {
 Ok(_) => return,
 Err(task) => task,
 };
- let task = match self.global.try_push(task) {
+ self.push_global(task, global);
+ }
+
+ #[inline]
+ fn push_global(&self, task: Task, global: &GlobalQueue) {
+ let task = match global.try_push(task) {
 Ok(_) => return,
 Err(task) => task,
 };
 self.local_abandoned.push(task);
 }
 
- fn try_pop(&self) -> Result, PopTaskError> {
+ #[inline]
+ fn try_pop(&self, global: &GlobalQueue) -> Result, PopTaskError> {
+ if !global.closed.can_pop() {
+ return Err(PopTaskError::Closed);
+ }
 // first try to get a previously abandoned task
 if let Some(task) = self.local_abandoned.pop() {
 return Ok(task);
@@ -215,74 +232,145 @@ impl WorkerQueues for ChannelWorkerQueues {
 // and try to refill it from the global queue if it's empty
 #[cfg(feature = "batching")]
 {
- self.try_pop_batch_or_refill().ok_or(PopTaskError::Empty)
+ self.local_batch.try_pop_or_refill(global, self)
 }
 // fall back to requesting a task from the global queue
 #[cfg(not(feature = "batching"))]
- self.global.try_pop()
+ {
+ global.try_pop()
+ }
 }
 
 #[cfg(feature = "retry")]
 fn try_push_retry(&self, task: Task) -> Result> {
 self.local_retry.try_push(task)
 }
+
+ /// Consumes this `LocalQueueShared` and drains the tasks currently in the queues into
+ /// `tasks`. 
+ fn drain_into(self, tasks: &mut Vec>) { + while let Some(task) = self.local_abandoned.pop() { + tasks.push(task); + } + #[cfg(feature = "batching")] + self.local_batch.drain_into(tasks); + #[cfg(feature = "retry")] + self.local_retry.drain_into(tasks); + } } #[cfg(feature = "batching")] mod batching { - use super::{ChannelWorkerQueues, Config, Task}; + use super::{GlobalQueue, LocalQueueShared, Task}; + use crate::atomic::{Atomic, AtomicUsize}; use crate::bee::Worker; + use crate::hive::inner::queue::PopTaskError; use crossbeam_queue::ArrayQueue; - use std::time::Duration; - - impl ChannelWorkerQueues { - pub fn update_batch(&self, config: &Config) { - let batch_limit = config.batch_limit.get_or_default().max(1); - let mut queue = self.local_batch.write(); - // block until the current queue is small enough that it can fit into the new queue - while queue.len() > batch_limit { - std::thread::sleep(Duration::from_millis(10)); - } - let new_queue = ArrayQueue::new(batch_limit); - while let Some(task) = queue.pop() { - if let Err(task) = new_queue - .push(task) - .or_else(|task| self.global.try_push(task)) - { - self.local_abandoned.push(task); - break; + use parking_lot::RwLock; + + pub struct WorkerBatchQueue { + inner: RwLock>>>, + limit: AtomicUsize, + } + + impl WorkerBatchQueue { + pub fn new(batch_limit: usize) -> Self { + if batch_limit == 0 { + Self { + inner: RwLock::new(None), + limit: Default::default(), + } + } else { + Self { + inner: RwLock::new(Some(ArrayQueue::new(batch_limit))), + limit: AtomicUsize::new(batch_limit), } } - assert!(queue.is_empty()); - *queue = new_queue; } - pub(super) fn try_pop_batch_or_refill(&self) -> Option> { - // pop from the local queue if it has any tasks - let local_queue = self.local_batch.read(); - if !local_queue.is_empty() { - return local_queue.pop(); + pub fn set_limit( + &self, + limit: usize, + global: &GlobalQueue, + parent: &LocalQueueShared, + ) { + // acquire the exclusive lock first to prevent simultaneous updates + let mut queue = self.inner.write(); + let old_limit = self.limit.set(limit); + if old_limit == limit { + return; + } + let old_queue = if limit == 0 { + queue.take() + } else { + queue.replace(ArrayQueue::new(limit)) + }; + if let Some(old_queue) = old_queue { + // try to push tasks from the old queue to the new one and fall back to pushing + // them to the global queue + old_queue + .into_iter() + .filter_map(|task| { + if let Some(new_queue) = queue.as_ref() { + new_queue.push(task).err() + } else { + Some(task) + } + }) + .for_each(|task| parent.push_global(task, global)); } - // otherwise pull at least 1 and up to `batch_limit + 1` tasks from the input channel - // wait for the next task from the receiver - let first = self.global.try_pop().ok(); - // if we fail after trying to get one, don't keep trying to fill the queue - if first.is_some() { - let batch_limit = local_queue.capacity(); + } + + pub fn try_push(&self, task: Task) -> Result<(), Task> { + if let Some(queue) = self.inner.read().as_ref() { + queue.push(task) + } else { + Err(task) + } + } + + pub fn try_pop_or_refill( + &self, + global: &GlobalQueue, + parent: &LocalQueueShared, + ) -> Result, PopTaskError> { + // pop from the local queue if it has any tasks + if let Some(local) = self.inner.read().as_ref() { + if !local.is_empty() { + if let Some(task) = local.pop() { + return Ok(task); + } + } + // otherwise pull at least 1 and up to `batch_limit + 1` tasks from the input channel + // wait for the next task from the receiver + let first = 
global.try_pop()?;
+ // if we succeed in getting the first task, try to refill the local queue
+ let limit = self.limit.get();
 // batch size 0 means batching is disabled
- if batch_limit > 0 {
+ if limit > 0 {
 // otherwise try to take up to `batch_limit` tasks from the input channel
 // and add them to the local queue, but don't block if the input channel
 // is empty
- for task in self.global.try_iter().take(batch_limit) {
- if let Err(task) = local_queue.push(task) {
- self.local_abandoned.push(task);
+ for task in global.try_iter().take(limit) {
+ if let Err(task) = local.push(task) {
+ parent.local_abandoned.push(task);
 break;
 }
 }
 }
+ Ok(first)
+ } else {
+ global.try_pop()
+ }
+ }
+
+ pub fn drain_into(self, tasks: &mut Vec>) {
+ if let Some(queue) = self.inner.into_inner() {
+ tasks.reserve(queue.len());
+ while let Some(task) = queue.pop() {
+ tasks.push(task);
+ }
 }
- first
 }
 }
}
diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs
index 69683ec..4dae392 100644
--- a/src/hive/inner/queue/mod.rs
+++ b/src/hive/inner/queue/mod.rs
@@ -1,4 +1,5 @@
 mod channel;
+mod closed;
 #[cfg(feature = "retry")]
 mod retry;
 mod workstealing;
@@ -6,9 +7,11 @@
 pub use self::channel::ChannelTaskQueues;
 pub use self::workstealing::WorkstealingTaskQueues;
 
+use self::closed::Closed;
+#[cfg(feature = "retry")]
+use self::retry::RetryQueue;
 use super::{Config, Task, Token};
 use crate::bee::Worker;
-use std::ops::Deref;
 
 /// Errors that may occur when trying to pop tasks from the global queue.
 #[derive(thiserror::Error, Debug)]
@@ -25,7 +28,6 @@ pub enum PopTaskError {
 /// This trait is sealed - it cannot be implemented outside of this crate.
 pub trait TaskQueues: Sized + Send + Sync + 'static {
 type WorkerQueues: WorkerQueues;
- type WorkerQueuesTarget: Deref;
 
 /// Returns a new instance.
 ///
@@ -39,7 +41,7 @@
 fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config);
 
 /// Returns a new `WorkerQueues` instance for a thread.
- fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueuesTarget;
+ fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueues;
 
 /// Tries to add a task to the global queue.
 ///
@@ -48,13 +50,17 @@
 /// Closes this `GlobalQueue` so no more tasks may be pushed.
 ///
+ /// If `urgent` is `true`, this also prevents queued tasks from being popped.
+ ///
 /// The private `Token` is used to prevent this method from being called externally.
- fn close(&self, token: Token);
+ fn close(&self, urgent: bool, token: Token);
 
 /// Drains all tasks from all global and local queues and returns them as a `Vec`.
 ///
- /// This is a destructive operation - if `close` has not been called, it will be called before
- /// draining the queues.
+ /// This is a destructive operation.
+ ///
+ /// This method panics if `close` has not been called. 
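+ ///
+ /// At the `Hive` level, the close-then-drain contract surfaces through the shutdown methods;
+ /// a minimal sketch (assuming the stock `EchoWorker`):
+ ///
+ /// ```
+ /// # use beekeeper::hive::DefaultHive;
+ /// # type MyWorker = beekeeper::bee::stock::EchoWorker<usize>;
+ /// let hive = DefaultHive::<MyWorker>::default();
+ /// for i in 0..10usize { hive.apply_store(i); }
+ /// // urgent shutdown: close the queues, then drain and abandon any queued tasks
+ /// assert!(hive.close(true));
+ /// ```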
fn drain(self) -> Vec>; } diff --git a/src/hive/inner/queue/retry.rs b/src/hive/inner/queue/retry.rs index 347f8a7..e018ac9 100644 --- a/src/hive/inner/queue/retry.rs +++ b/src/hive/inner/queue/retry.rs @@ -177,7 +177,7 @@ mod tests { assert_eq!(queue.try_pop(), Some(task2)); assert_eq!(queue.len(), 1); - thread::sleep(Duration::from_secs(1)); + thread::sleep(Duration::from_secs(2)); assert_eq!(queue.try_pop(), Some(task3)); assert_eq!(queue.len(), 0); diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index aeba049..48acb9d 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -1,5 +1,12 @@ -use super::{Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; -use crate::atomic::{Atomic, AtomicBool}; +//! Implementation of `TaskQueues` that uses workstealing to distribute tasks among worker threads. +//! Tasks are sent from the `Hive` via a global `Injector` queue. Each worker thread has a local +//! `Worker` queue where tasks can be pushed. If the local queue is empty, the worker thread first +//! tries to steal a task from the global queue and falls back to stealing from another worker +//! thread. If the `batching` feature is enabled, a worker thread will try to fill its local queue +//! up to the limit when stealing from the global queue. +use super::{Closed, Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; +#[cfg(feature = "batching")] +use crate::atomic::Atomic; use crate::bee::Worker; use crossbeam_deque::{Injector, Stealer}; use crossbeam_queue::SegQueue; @@ -15,7 +22,6 @@ pub struct WorkstealingTaskQueues { impl TaskQueues for WorkstealingTaskQueues { type WorkerQueues = WorkstealingWorkerQueues; - type WorkerQueuesTarget = Self::WorkerQueues; fn new(_: Token) -> Self { Self { @@ -28,11 +34,7 @@ impl TaskQueues for WorkstealingTaskQueues { let mut local_queues = self.local.write(); assert_eq!(local_queues.len(), start_index); (start_index..end_index).for_each(|thread_index| { - local_queues.push(Arc::new(LocalQueueShared::new( - thread_index, - &self.global, - config, - ))); + local_queues.push(Arc::new(LocalQueueShared::new(thread_index, config))); }); } @@ -42,23 +44,24 @@ impl TaskQueues for WorkstealingTaskQueues { (start_index..end_index).for_each(|thread_index| local_queues[thread_index].update(config)); } - fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueuesTarget { + fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueues { let local_queue = crossbeam_deque::Worker::new_fifo(); self.global.add_stealer(local_queue.stealer()); - let shared = &self.local.read()[thread_index]; - WorkstealingWorkerQueues::new(local_queue, Arc::clone(shared)) + WorkstealingWorkerQueues::new(local_queue, &self.global, &self.local.read()[thread_index]) } fn try_push_global(&self, task: Task) -> Result<(), Task> { self.global.try_push(task) } - fn close(&self, _: Token) { - self.global.close(); + fn close(&self, urgent: bool, _: Token) { + self.global.close(urgent); } fn drain(self) -> Vec> { - self.close(Token); + if !self.global.is_closed() { + panic!("close must be called before drain"); + } let mut tasks = Vec::new(); let global = crate::hive::unwrap_arc(self.global); global.drain_into(&mut tasks); @@ -73,7 +76,7 @@ impl TaskQueues for WorkstealingTaskQueues { pub struct GlobalQueue { queue: Injector>, stealers: RwLock>>>, - closed: AtomicBool, + closed: Closed, } impl GlobalQueue { @@ -81,7 +84,7 @@ impl GlobalQueue { Self { queue: Injector::new(), stealers: 
Default::default(), - closed: AtomicBool::default(), + closed: Default::default(), } } @@ -90,18 +93,18 @@ impl GlobalQueue { } fn try_push(&self, task: Task) -> Result<(), Task> { - if self.closed.get() { + if !self.closed.can_push() { return Err(task); } self.queue.push(task); Ok(()) } + /// Tries to steal a task from a random worker using its `Stealer`. fn try_steal(&self) -> Option> { let stealers = self.stealers.read(); let n = stealers.len(); // randomize the stealing order, to prevent always stealing from the same thread - // TODO: put this into a shared global to prevent creating a new instance on every call std::iter::from_fn(|| Some(rand::rng().random_range(0..n))) .take(n) .filter_map(|i| stealers[i].steal().success()) @@ -110,8 +113,7 @@ impl GlobalQueue { /// Tries to steal a task from the global queue, otherwise tries to steal a task from another /// worker thread. - #[cfg(not(feature = "batching"))] - fn try_pop(&self) -> Result, PopTaskError> { + fn try_pop_unchecked(&self) -> Result, PopTaskError> { if let Some(task) = self.queue.steal().success() { Ok(task) } else { @@ -119,8 +121,9 @@ impl GlobalQueue { } } - /// Tries to steal up to `limit` tasks from the global queue. If at least one task was stolen, - /// it is popped and returned. Otherwise tries to steal a task from another worker thread. + /// Tries to steal up to `limit + 1` tasks from the global queue. If at least one task was + /// stolen, it is popped and returned. Otherwise tries to steal a task from another worker + /// thread. #[cfg(feature = "batching")] fn try_refill_and_pop( &self, @@ -129,7 +132,7 @@ impl GlobalQueue { ) -> Result, PopTaskError> { if let Some(task) = self .queue - .steal_batch_with_limit_and_pop(local_batch, limit) + .steal_batch_with_limit_and_pop(local_batch, limit + 1) .success() { return Ok(task); @@ -137,14 +140,21 @@ impl GlobalQueue { self.try_steal().ok_or(PopTaskError::Empty) } - fn close(&self) { - self.closed.set(true); + fn is_closed(&self) -> bool { + self.closed.is_closed() + } + + fn close(&self, urgent: bool) { + self.closed.set(urgent); } fn drain_into(self, tasks: &mut Vec>) { while let Some(task) = self.queue.steal().success() { tasks.push(task); } + // since the `TaskQueues` instance does not retain a reference to the workers' queues + // (it can't, because they're not Send/Sync), the only way we have to drain them is via + // their stealers self.stealers.into_inner().into_iter().for_each(|stealer| { while let Some(task) = stealer.steal().success() { tasks.push(task); @@ -154,23 +164,32 @@ impl GlobalQueue { } pub struct WorkstealingWorkerQueues { - queue: crossbeam_deque::Worker>, + local: crossbeam_deque::Worker>, + global: Arc>, shared: Arc>, } impl WorkstealingWorkerQueues { - fn new(queue: crossbeam_deque::Worker>, shared: Arc>) -> Self { - Self { queue, shared } + fn new( + local: crossbeam_deque::Worker>, + global: &Arc>, + shared: &Arc>, + ) -> Self { + Self { + global: Arc::clone(global), + local, + shared: Arc::clone(shared), + } } } impl WorkerQueues for WorkstealingWorkerQueues { fn push(&self, task: Task) { - self.queue.push(task); + self.local.push(task); } fn try_pop(&self) -> Result, PopTaskError> { - self.shared.try_pop(&self.queue) + self.shared.try_pop(&self.global, &self.local) } #[cfg(feature = "retry")] @@ -189,26 +208,24 @@ impl Deref for WorkstealingWorkerQueues { struct LocalQueueShared { _thread_index: usize, - global: Arc>, /// queue of abandon tasks local_abandoned: SegQueue>, #[cfg(feature = "batching")] batch_limit: 
crate::atomic::AtomicUsize, /// thread-local queues used for tasks that are waiting to be retried after a failure #[cfg(feature = "retry")] - local_retry: super::retry::RetryQueue, + local_retry: super::RetryQueue, } impl LocalQueueShared { - fn new(thread_index: usize, global: &Arc>, _config: &Config) -> Self { + fn new(thread_index: usize, _config: &Config) -> Self { Self { _thread_index: thread_index, - global: Arc::clone(global), local_abandoned: Default::default(), #[cfg(feature = "batching")] batch_limit: crate::atomic::AtomicUsize::new(_config.batch_limit.get_or_default()), #[cfg(feature = "retry")] - local_retry: super::retry::RetryQueue::new(_config.retry_factor.get_or_default()), + local_retry: super::RetryQueue::new(_config.retry_factor.get_or_default()), } } @@ -222,8 +239,12 @@ impl LocalQueueShared { fn try_pop( &self, + global: &GlobalQueue, local_batch: &crossbeam_deque::Worker>, ) -> Result, PopTaskError> { + if !global.closed.can_pop() { + return Err(PopTaskError::Closed); + } // first try to get a previously abandoned task if let Some(task) = self.local_abandoned.pop() { return Ok(task); @@ -237,20 +258,16 @@ impl LocalQueueShared { if let Some(task) = local_batch.pop() { return Ok(task); } - // fall back to requesting a task from the global queue - this will also refill the local - // batch queue if the batching feature is enabled - if let Some(task) = local_batch.pop() { - return Ok(task); - } + // fall back to requesting a task from the global queue - if batching is enabled, this will + // also try to refill the local queue #[cfg(feature = "batching")] { - self.global - .try_refill_and_pop(local_batch, self.batch_limit.get()) - } - #[cfg(not(feature = "batching"))] - { - self.global.try_pop() + let limit = self.batch_limit.get(); + if limit > 0 { + return global.try_refill_and_pop(local_batch, limit); + } } + global.try_pop_unchecked() } fn drain_into(self, tasks: &mut Vec>) { diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index 96119fd..41be1c9 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -147,7 +147,7 @@ impl, T: TaskQueues> Shared { } /// Returns the `WorkerQueues` instance for the worker thread with the specified index. - pub fn worker_queues(&self, thread_index: usize) -> T::WorkerQueuesTarget { + pub fn worker_queues(&self, thread_index: usize) -> T::WorkerQueues { self.task_queues.worker_queues(thread_index) } @@ -421,7 +421,7 @@ impl, T: TaskQueues> Shared { /// 3. Resumes the hive if it is suspendend, which enables blocked worker threads to terminate. pub fn poison(&self) { self.poisoned.set(true); - self.close_task_queues(); + self.close_task_queues(true); self.set_suspended(false); } @@ -487,15 +487,14 @@ impl, T: TaskQueues> Shared { } /// Close the tasks queues so no more tasks can be added. - pub fn close_task_queues(&self) { - self.task_queues.close(Token); + pub fn close_task_queues(&self, urgent: bool) { + self.task_queues.close(urgent, Token); } fn flush( task_queues: T, mut outcomes: HashMap>, ) -> HashMap> { - task_queues.close(Token); for task in task_queues.drain().into_iter() { let task_id = task.id(); let (outcome, outcome_tx) = task.into_unprocessed(); diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 3b0c537..40672fa 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -32,24 +32,25 @@ //! the `Builder`. //! //! ``` -//! use beekeeper::hive::Builder; +//! use beekeeper::hive::{Builder, ChannelBuilder}; //! # type MyWorker1 = beekeeper::bee::stock::EchoWorker; //! 
# type MyWorker2 = beekeeper::bee::stock::EchoWorker; //! -//! let builder1 = Builder::default(); +//! let builder1 = ChannelBuilder::default(); //! let builder2 = builder1.clone(); //! -//! let hive1 = builder1.build_with_default::(); -//! let hive2 = builder2.build_with_default::(); +//! let hive1 = builder1.with_worker_default::().build(); +//! let hive2 = builder2.with_worker_default::().build(); //! ``` //! //! If you want a `Hive` with the global defaults for a `Worker` type that implements `Default`, -//! you can call [`Hive::default`](crate::hive::Hive::default) rather than use a `Builder`. +//! you can call [`DefaultHive::::default`](crate::hive::Hive::default) rather than use a +//! `Builder`. //! //! ``` -//! # use beekeeper::hive::Hive; +//! # use beekeeper::hive::DefaultHive; //! # type MyWorker = beekeeper::bee::stock::EchoWorker; -//! let hive: Hive = Hive::default(); +//! let hive = DefaultHive::::default(); //! ``` //! //! ## Thread affinity (requires `feature = "affinity"`) @@ -78,14 +79,15 @@ //! started with no core affinity. //! //! ``` -//! use beekeeper::hive::Builder; +//! use beekeeper::hive::{Builder, ChannelBuilder}; //! # type MyWorker = beekeeper::bee::stock::EchoWorker; //! -//! let hive = Builder::new() +//! let hive = ChannelBuilder::empty() //! .num_threads(4) //! // 16 cores will be available for pinning but only 4 will be used initially //! .core_affinity(0..16) -//! .build_with_default::(); +//! .with_worker_default::() +//! .build(); //! //! // increase the number of threads by 12 - the new threads will use the additiona //! // 12 available cores for pinning @@ -266,13 +268,13 @@ //! task IDs. //! //! ``` -//! use beekeeper::hive::{Hive, OutcomeIteratorExt, outcome_channel}; +//! use beekeeper::hive::{DefaultHive, OutcomeIteratorExt, outcome_channel}; //! # type MyWorker = beekeeper::bee::stock::EchoWorker; //! -//! let hive: Hive = Hive::default(); +//! let hive = DefaultHive::::default(); //! let (tx, rx) = outcome_channel::(); -//! let batch1 = hive.swarm_send(0..10, tx.clone()); -//! let batch2 = hive.swarm_send(10..20, tx.clone()); +//! let batch1 = hive.swarm_send(0..10, &tx); +//! let batch2 = hive.swarm_send(10..20, tx); //! let outputs: Vec<_> = rx.into_iter() //! .select_ordered_outputs(batch1.into_iter().chain(batch2.into_iter())) //! .collect(); @@ -287,10 +289,10 @@ //! (see below), which provides a common interface for accessing stored `Outcome`s. //! //! ``` -//! use beekeeper::hive::{Hive, OutcomeStore}; +//! use beekeeper::hive::{DefaultHive, OutcomeStore}; //! # type MyWorker = beekeeper::bee::stock::EchoWorker; //! -//! let hive: Hive = Hive::default(); +//! let hive = DefaultHive::::default(); //! let (outcomes, sum) = hive.scan(0..10, 0, |sum, i| { //! *sum += i; //! 
i * 2 @@ -365,7 +367,7 @@ mod inner; mod outcome; pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder}; -pub use self::hive::{Hive, Poisoned}; +pub use self::hive::{DefaultHive, Hive, Poisoned}; pub use self::husk::Husk; pub use self::inner::{set_config::*, Builder, ChannelTaskQueues, WorkstealingTaskQueues}; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; @@ -505,7 +507,7 @@ mod tests { } #[test] - fn test_grow() { + fn test_grow_from_nonzero() { let hive = void_thunk_hive(TEST_TASKS, false); // queue some long-running tasks for _ in 0..TEST_TASKS { @@ -523,7 +525,7 @@ mod tests { } thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().1, total_threads as u64); - let husk = hive.try_into_husk().unwrap(); + let husk = hive.try_into_husk(false).unwrap(); assert_eq!(husk.iter_successes().count(), total_threads); } @@ -636,7 +638,7 @@ mod tests { hive.join(); // Ensure that none of the threads have panicked assert_eq!(hive.num_panics(), TEST_TASKS); - let husk = hive.try_into_husk().unwrap(); + let husk = hive.try_into_husk(false).unwrap(); assert_eq!(husk.num_panics(), TEST_TASKS); } @@ -789,7 +791,7 @@ mod tests { let debug = format!("{:?}", hive); assert_eq!( debug, - "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" ); let hive: THive = ChannelBuilder::empty() @@ -800,7 +802,7 @@ mod tests { let debug = format!("{:?}", hive); assert_eq!( debug, - "Hive { task_tx: Sender { .. }, shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + "Hive { shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" ); let hive = thunk_hive(4, true); @@ -809,7 +811,7 @@ mod tests { let debug = format!("{:?}", hive); assert_eq!( debug, - "Hive { task_tx: Sender { .. }, shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" + "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" ); } @@ -975,7 +977,7 @@ mod tests { i }) }), - &tx, + tx, ); let (mut outcome_task_ids, values): (Vec, Vec) = rx .iter() @@ -1055,7 +1057,7 @@ mod tests { i }) }), - &tx, + tx, ); let (mut outcome_task_ids, values): (Vec, Vec) = rx .iter() @@ -1127,7 +1129,7 @@ mod tests { .num_threads(4) .build(); let (tx, rx) = super::outcome_channel(); - let (mut task_ids, state) = hive.scan_send(0..10, &tx, 0, |acc, i| { + let (mut task_ids, state) = hive.scan_send(0..10, tx, 0, |acc, i| { *acc += i; *acc }); @@ -1163,7 +1165,7 @@ mod tests { .num_threads(4) .build(); let (tx, rx) = super::outcome_channel(); - let (results, state) = hive.try_scan_send(0..10, &tx, 0, |acc, i| { + let (results, state) = hive.try_scan_send(0..10, tx, 0, |acc, i| { *acc += i; Ok::<_, String>(*acc) }); @@ -1302,7 +1304,7 @@ mod tests { let hive1 = thunk_hive::(8, false); let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::of(move || i))); hive1.join(); - let mut husk1 = hive1.try_into_husk().unwrap(); + let mut husk1 = hive1.try_into_husk(false).unwrap(); for i in task_ids.iter() { assert!(husk1.outcomes_deref().get(i).unwrap().is_success()); assert!(matches!(husk1.get(*i), Some(Outcome::Success { .. 
}))); @@ -1321,7 +1323,7 @@ mod tests { }) })); hive2.join(); - let mut husk2 = hive2.try_into_husk().unwrap(); + let mut husk2 = hive2.try_into_husk(false).unwrap(); let mut outputs1 = husk1 .remove_all() @@ -1345,7 +1347,7 @@ mod tests { }) })); hive3.join(); - let husk3 = hive3.try_into_husk().unwrap(); + let husk3 = hive3.try_into_husk(false).unwrap(); let (_, outcomes3) = husk3.into_parts(); let mut outputs3 = outcomes3 .into_iter() @@ -1660,7 +1662,12 @@ mod tests { assert_eq!(output, b"abcdefgh"); // shutdown the hive, use the Queen to wait on child processes, and report errors - let mut queen = hive.try_into_husk().unwrap().into_parts().0.into_inner(); + let mut queen = hive + .try_into_husk(false) + .unwrap() + .into_parts() + .0 + .into_inner(); let (wait_ok, wait_err): (Vec<_>, Vec<_>) = queen.wait_for_all().into_iter().partition(Result::is_ok); if !wait_err.is_empty() { diff --git a/src/lib.rs b/src/lib.rs index a60facf..4faf4c6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -96,6 +96,15 @@ //! more performant than the ordered iterator //! * The methods with the `_send` suffix accept a channel [`Sender`](crate::channel::Sender) and //! send the `Outcome`s to that channel as they are completed +//! * Note that, for these methods, the `tx` parameter is of type `Borrow>`, which +//! allows you to pass in either a value or a reference. Passing a value causes the `Sender` +//! to be dropped after the call, while passing a reference allows you to use the same +//! `Sender` for multiple `_send` calls. Note that in the later case, you need to explicitly +//! drop the sender (e.g., `drop(tx)`), pass it by value to the last `_send` call, or be +//! careful about how you obtain outcomes from the `Receiver` as methods such as `recv` and +//! `iter` will block until the `Sender` is dropped. You should *not* pass clones of the +//! `Sender` to `_send` methods as this results in slightly worse performance and still has +//! the requirement that you manually drop the original `Sender` value. //! * The methods with the `_store` suffix store the `Outcome`s in the `Hive`; these may be //! retrieved later using the [`Hive::take_stored()`](crate::hive::Hive::take_stored) method, //! using one of the `remove*` methods (which requires @@ -151,10 +160,11 @@ //! # fn main() { //! // create a hive to process `Thunk`s - no-argument closures with the //! // same return type (`i32`) -//! let hive = Builder::new() +//! let hive = ChannelBuilder::empty() //! .num_threads(4) //! .thread_name("thunk_hive") -//! .build_with_default::>(); +//! .with_worker_default::>() +//! .build(); //! //! // return results to your own channel... //! let (tx, rx) = outcome_channel(); @@ -220,7 +230,7 @@ //! fn apply( //! &mut self, //! input: Self::Input, -//! _: &Context +//! _: &Context //! ) -> WorkerResult { //! self.write_char(input).map_err(|error| { //! ApplyError::Fatal { input: Some(input), error } @@ -242,7 +252,7 @@ //! } //! } //! -//! impl Queen for CatQueen { +//! impl QueenMut for CatQueen { //! type Kind = CatWorker; //! //! fn create(&mut self) -> Self::Kind { @@ -277,9 +287,10 @@ //! //! # fn main() { //! // build the Hive -//! let hive = Builder::new() +//! let hive = ChannelBuilder::empty() //! .num_threads(4) -//! .build_default::(); +//! .with_queen_mut_default::() +//! .build(); //! //! // prepare inputs //! let inputs = (0..8).map(|i| 97 + i); @@ -300,9 +311,9 @@ //! //! // shutdown the hive, use the Queen to wait on child processes, and //! // report errors -//! 
let (mut queen, _outcomes) = hive.try_into_husk().unwrap().into_parts(); +//! let (queen, _outcomes) = hive.try_into_husk(false).unwrap().into_parts(); //! let (wait_ok, wait_err): (Vec<_>, Vec<_>) = -//! queen.wait_for_all().into_iter().partition(Result::is_ok); +//! queen.into_inner().wait_for_all().into_iter().partition(Result::is_ok); //! if !wait_err.is_empty() { //! panic!( //! "Error(s) occurred while waiting for child processes: {:?}", From 2ec401561870ec2c11a784d8120076872dee8dfd Mon Sep 17 00:00:00 2001 From: jdidion Date: Wed, 19 Feb 2025 22:49:43 -0800 Subject: [PATCH 16/67] add missing files --- src/hive/inner/queue/closed.rs | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/hive/inner/queue/closed.rs diff --git a/src/hive/inner/queue/closed.rs b/src/hive/inner/queue/closed.rs new file mode 100644 index 0000000..a0e447a --- /dev/null +++ b/src/hive/inner/queue/closed.rs @@ -0,0 +1,31 @@ +use crate::atomic::{Atomic, AtomicU8}; + +const OPEN: u8 = 0; +const CLOSED_PUSH: u8 = 1; +const CLOSED_POP: u8 = 2; + +pub struct Closed(AtomicU8); + +impl Closed { + pub fn is_closed(&self) -> bool { + self.0.get() > OPEN + } + + pub fn can_push(&self) -> bool { + self.0.get() < CLOSED_PUSH + } + + pub fn can_pop(&self) -> bool { + self.0.get() < CLOSED_POP + } + + pub fn set(&self, urgent: bool) { + self.0.set(if urgent { CLOSED_POP } else { CLOSED_PUSH }); + } +} + +impl Default for Closed { + fn default() -> Self { + Self(AtomicU8::new(OPEN)) + } +} From 5671c45137c81d49e9248d24f95d0e0c23afee20 Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 09:28:41 -0800 Subject: [PATCH 17/67] fix formatting --- src/bee/worker.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/bee/worker.rs b/src/bee/worker.rs index cd85e16..adbc93b 100644 --- a/src/bee/worker.rs +++ b/src/bee/worker.rs @@ -133,7 +133,11 @@ mod tests { type Output = u8; type Error = (); - fn apply_ref(&mut self, input: &Self::Input, _: &Context) -> RefWorkerResult { + fn apply_ref( + &mut self, + input: &Self::Input, + _: &Context, + ) -> RefWorkerResult { match *input { 0 => Err(ApplyRefError::Retryable(())), 1 => Err(ApplyRefError::Fatal(())), From dfa184c314dd8168cb3d88d1417961d1b842d9bb Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 09:29:42 -0800 Subject: [PATCH 18/67] fix --- Cargo.toml | 2 +- src/hive/inner/queue/channel.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index de54cfd..174fe97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ name = "perf" harness = false [features] -default = ["batching", "retry"] +default = [] affinity = ["dep:core_affinity"] batching = [] retry = [] diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index 1a72689..1fa5bbf 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -237,7 +237,7 @@ impl LocalQueueShared { // fall back to requesting a task from the global queue #[cfg(not(feature = "batching"))] { - self.global.try_pop() + global.try_pop() } } From 887dd6fb78c4299fb22f8b80cd179f69071890d7 Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 09:42:38 -0800 Subject: [PATCH 19/67] fix test --- src/hive/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 40672fa..d9d4550 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -1049,11 +1049,13 @@ mod tests { #[test] fn test_swarm_send() { let hive = 
thunk_hive::(8, false); + #[cfg(feature = "batching")] + assert_eq!(hive.worker_batch_limit(), 0); let (tx, rx) = super::outcome_channel(); let mut task_ids = hive.swarm_send( (0..8u8).map(|i| { Thunk::of(move || { - thread::sleep(Duration::from_millis((8 - i as u64) * 100)); + thread::sleep(Duration::from_millis((8 - i as u64) * 200)); i }) }), From 36801608a56837a4bd2760c4ce6dbe69ae73216b Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 11:26:28 -0800 Subject: [PATCH 20/67] cleanup --- Cargo.toml | 2 +- src/hive/hive.rs | 2 +- src/hive/inner/queue/channel.rs | 41 ++++++------ src/hive/inner/queue/mod.rs | 21 +++--- src/hive/inner/queue/retry.rs | 66 +++++++++---------- src/hive/inner/queue/{closed.rs => status.rs} | 6 +- src/hive/inner/queue/workstealing.rs | 20 +++--- src/hive/mod.rs | 16 +++-- 8 files changed, 91 insertions(+), 83 deletions(-) rename src/hive/inner/queue/{closed.rs => status.rs} (88%) diff --git a/Cargo.toml b/Cargo.toml index 174fe97..de54cfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ name = "perf" harness = false [features] -default = [] +default = ["batching", "retry"] affinity = ["dep:core_affinity"] batching = [] retry = [] diff --git a/src/hive/hive.rs b/src/hive/hive.rs index cf22a60..8622127 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -550,7 +550,7 @@ impl, T: TaskQueues> Hive { // wait for all tasks to finish shared.wait_on_done(); // unwrap the Arc and return the inner Shared value - Some(super::unwrap_arc(shared)) + Some(super::unwrap_arc(shared).expect("timeout waiting to take ownership of shared data")) } /// Consumes this `Hive` and attempts to shut it down gracefully. diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index 1fa5bbf..1bdf321 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -1,7 +1,7 @@ //! Implementation of `TaskQueues` that uses `crossbeam` channels for the global queue (i.e., for //! sending tasks from the `Hive` to the worker threads) and a default implementation of local //! queues that depends on which combination of the `retry` and `batching` features are enabled. 
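To make the queue mechanics in this file easier to follow, here is a minimal, self-contained sketch of the pattern it implements, assuming nothing beyond `crossbeam-channel` and `std`: an unbounded channel acts as the closable global queue, and an atomic flag (the `Closed`/`Status` type from this series) distinguishes "closed to pushes" from "urgently closed to pops". The `Global` type and its methods are illustrative stand-ins, not the crate's actual API; the status constants and the one-second receive timeout mirror the diff.

```rust
use crossbeam_channel::{unbounded, Receiver, Sender};
use std::sync::atomic::{AtomicU8, Ordering};
use std::time::Duration;

// same encoding as the `Closed`/`Status` type in this series
const OPEN: u8 = 0;
const CLOSED_PUSH: u8 = 1;
const CLOSED_POP: u8 = 2;

// illustrative stand-in for the module's global queue
struct Global<T> {
    tx: Sender<T>,
    rx: Receiver<T>,
    status: AtomicU8,
}

impl<T> Global<T> {
    fn new() -> Self {
        let (tx, rx) = unbounded();
        Self { tx, rx, status: AtomicU8::new(OPEN) }
    }

    /// Rejects the task (returns it to the caller) once the queue is closed.
    fn try_push(&self, task: T) -> Result<(), T> {
        if self.status.load(Ordering::Acquire) >= CLOSED_PUSH {
            return Err(task);
        }
        self.tx.send(task).map_err(|err| err.into_inner())
    }

    /// Blocks for at most one second so a worker can periodically re-check
    /// the status flag; an urgent close leaves queued tasks for draining.
    fn try_pop(&self) -> Option<T> {
        if self.status.load(Ordering::Acquire) >= CLOSED_POP {
            return None;
        }
        self.rx.recv_timeout(Duration::from_secs(1)).ok()
    }

    fn close(&self, urgent: bool) {
        let state = if urgent { CLOSED_POP } else { CLOSED_PUSH };
        self.status.store(state, Ordering::Release);
    }
}

fn main() {
    let q = Global::new();
    q.try_push(1).unwrap();
    assert_eq!(q.try_pop(), Some(1));
    q.close(false); // no more pushes, but queued tasks could still be popped
    assert!(q.try_push(2).is_err());
}
```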
-use super::{Closed, Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; +use super::{Config, PopTaskError, Status, Task, TaskQueues, Token, WorkerQueues}; use crate::bee::Worker; use crossbeam_channel::RecvTimeoutError; use crossbeam_queue::SegQueue; @@ -9,7 +9,7 @@ use parking_lot::RwLock; use std::sync::Arc; use std::time::Duration; -// time to wait in between polling the retry queue and then the task receiver +// time to wait when polling the global queue const RECV_TIMEOUT: Duration = Duration::from_secs(1); /// Type alias for the input task channel sender @@ -48,14 +48,14 @@ impl TaskQueues for ChannelTaskQueues { .for_each(|queue| queue.update(&self.global, config)); } - fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueues { - ChannelWorkerQueues::new(&self.global, &self.local.read()[thread_index]) - } - fn try_push_global(&self, task: Task) -> Result<(), Task> { self.global.try_push(task) } + fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueues { + ChannelWorkerQueues::new(&self.global, &self.local.read()[thread_index]) + } + fn close(&self, urgent: bool, _: Token) { self.global.close(urgent) } @@ -65,10 +65,12 @@ impl TaskQueues for ChannelTaskQueues { panic!("close must be called before drain"); } let mut tasks = Vec::new(); - let global = crate::hive::unwrap_arc(self.global); + let global = crate::hive::unwrap_arc(self.global) + .unwrap_or_else(|_| panic!("timeout waiting to take ownership of global queue")); global.drain_into(&mut tasks); for local in self.local.into_inner().into_iter() { - let local = crate::hive::unwrap_arc(local); + let local = crate::hive::unwrap_arc(local) + .unwrap_or_else(|_| panic!("timeout waiting to take ownership of local queue")); local.drain_into(&mut tasks); } tasks @@ -78,7 +80,7 @@ impl TaskQueues for ChannelTaskQueues { pub struct GlobalQueue { global_tx: TaskSender, global_rx: TaskReceiver, - closed: Closed, + status: Status, } impl GlobalQueue { @@ -87,13 +89,13 @@ impl GlobalQueue { Self { global_tx: tx, global_rx: rx, - closed: Default::default(), + status: Default::default(), } } #[inline] fn try_push(&self, task: Task) -> Result<(), Task> { - if !self.closed.can_push() { + if !self.status.can_push() { return Err(task); } self.global_tx.send(task).map_err(|err| err.into_inner()) @@ -111,23 +113,24 @@ impl GlobalQueue { } } - #[cfg(feature = "batching")] - fn try_iter(&self) -> impl Iterator> + '_ { - self.global_rx.try_iter() - } - #[inline] fn is_closed(&self) -> bool { - self.closed.is_closed() + self.status.is_closed() } fn close(&self, urgent: bool) { - self.closed.set(urgent); + self.status.set(urgent); } fn drain_into(self, tasks: &mut Vec>) { + tasks.reserve(self.global_rx.len()); tasks.extend(self.global_rx.try_iter()); } + + #[cfg(feature = "batching")] + fn try_iter(&self) -> impl Iterator> + '_ { + self.global_rx.try_iter() + } } pub struct ChannelWorkerQueues { @@ -216,7 +219,7 @@ impl LocalQueueShared { #[inline] fn try_pop(&self, global: &GlobalQueue) -> Result, PopTaskError> { - if !global.closed.can_pop() { + if !global.status.can_pop() { return Err(PopTaskError::Closed); } // first try to get a previously abandoned task diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs index 4dae392..902e25f 100644 --- a/src/hive/inner/queue/mod.rs +++ b/src/hive/inner/queue/mod.rs @@ -1,5 +1,5 @@ mod channel; -mod closed; +mod status; #[cfg(feature = "retry")] mod retry; mod workstealing; @@ -7,7 +7,7 @@ mod workstealing; pub use self::channel::ChannelTaskQueues; pub use 
self::workstealing::WorkstealingTaskQueues;
-use self::closed::Closed;
+use self::status::Status;
#[cfg(feature = "retry")]
use self::retry::RetryQueue;
use super::{Config, Task, Token};
use crate::bee::Worker;
@@ -40,14 +40,14 @@ pub trait TaskQueues<W: Worker>: Sized + Send + Sync + 'static {
    /// Updates the queue settings from `config` for the given range of worker threads.
    fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config);

-    /// Returns a new `WorkerQueues` instance for a thread.
-    fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueues;
-
    /// Tries to add a task to the global queue.
    ///
    /// Returns an error with the task if the queue is disconnected.
    fn try_push_global(&self, task: Task<W>) -> Result<(), Task<W>>;

+    /// Returns a `WorkerQueues` instance for the worker thread with the given `index`.
+    fn worker_queues(&self, thread_index: usize) -> Self::WorkerQueues;
+
    /// Closes this `TaskQueues` so no more tasks may be pushed.
    ///
    /// If `urgent` is `true`, this also prevents queued tasks from being popped.
    ///
    /// The private `Token` is used to prevent this method from being called externally.
    fn close(&self, urgent: bool, token: Token);

-    /// Drains all tasks from all global and local queues and returns them as a `Vec`.
-    ///
-    /// This is a destructive operation - if `close` has not been called, it will be called before
-    /// draining the queues.
+    /// Consumes this `TaskQueues` and drains all tasks from all global and local queues and
+    /// returns them as a `Vec`.
    ///
    /// This method panics if `close` has not been called.
    fn drain(self) -> Vec<Task<W>>;
@@ -69,7 +67,8 @@ pub trait WorkerQueues<W: Worker> {
    /// Attempts to add a task to the local queue if space is available, otherwise adds it to the
    /// global queue. If adding to the global queue fails, the task is added to a local "abandoned"
-    /// queue from which it may be popped or will otherwise be converted.
+    /// queue from which it may be popped or will otherwise be converted to an `Unprocessed`
+    /// outcome.
    fn push(&self, task: Task<W>);

    /// Attempts to remove a task from the local queue for the given worker thread index. If there
    ///
    /// Returns an error if a task is not available, where each implementation may have a different
    /// definition of "available".
    ///
-    /// Also returns an error if the queue is empty or disconnected.
+    /// Also returns an error if the queues are closed.
    fn try_pop(&self) -> Result<Task<W>, PopTaskError>;

    /// Attempts to add `task` to the local retry queue.
diff --git a/src/hive/inner/queue/retry.rs b/src/hive/inner/queue/retry.rs
index e018ac9..4a55056 100644
--- a/src/hive/inner/queue/retry.rs
+++ b/src/hive/inner/queue/retry.rs
@@ -6,9 +6,9 @@ use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::time::{Duration, Instant};

-/// A queue where each item has an associated `Instant` at which it will be available.
+/// A task queue where each task has an associated `Instant` at which it will be available.
///
-/// This is implemented internally as a `UnsafeCell`.
+/// This is implemented internally as `UnsafeCell`.
///
/// SAFETY: This data structure is designed to enable the queue to be modified (using `push` and
/// `try_pop`) by a *single thread* using interior mutability.
The `drain` method is called by a
@@ -24,6 +24,7 @@ pub struct RetryQueue<W: Worker> {
}

impl<W: Worker> RetryQueue<W> {
+    /// Creates a new `RetryQueue` with the given `delay_factor` (in nanoseconds).
    pub fn new(delay_factor: u64) -> Self {
        Self {
            inner: UnsafeCell::new(BinaryHeap::new()),
@@ -31,30 +32,31 @@ impl<W: Worker> RetryQueue<W> {
        }
    }

+    /// Changes the delay factor for the queue.
    pub fn set_delay_factor(&self, delay_factor: u64) {
        self.delay_factor.set(delay_factor);
    }

-    /// Pushes an item onto the queue. Returns the `Instant` at which the item will be available,
-    /// or an error with `item` if there was an error pushing the item.
+    /// Pushes a task onto the queue. Returns the `Instant` at which the task will be available,
+    /// or an error with `task` if there was an error pushing it.
    ///
    /// SAFETY: this method is only ever called within a single thread.
    pub fn try_push(&self, task: Task<W>) -> Result<Instant, Task<W>> {
-        // compute the delay
-        let delay = 2u64
-            .checked_pow(task.attempt - 1)
-            .and_then(|multiplier| {
-                self.delay_factor
-                    .get()
-                    .checked_mul(multiplier)
-                    .or(Some(u64::MAX))
-                    .map(Duration::from_nanos)
-            })
-            .unwrap_or_default();
        unsafe {
            match self.inner.get().as_mut() {
                Some(queue) => {
-                    let delayed = Delayed::new(task, delay);
+                    // compute the delay
+                    let delay = 2u64
+                        .checked_pow(task.attempt - 1)
+                        .and_then(|multiplier| {
+                            self.delay_factor
+                                .get()
+                                .checked_mul(multiplier)
+                                .or(Some(u64::MAX))
+                                .map(Duration::from_nanos)
+                        })
+                        .unwrap_or_default();
+                    let delayed = DelayedTask::new(task, delay);
                    let until = delayed.until;
                    queue.push(delayed);
                    Ok(until)
@@ -64,7 +66,7 @@
        }
    }

-    /// Returns the item at the head of the queue, if one exists and is available (i.e., its delay
+    /// Returns the task at the head of the queue, if one exists and is available (i.e., its delay
    /// has been exceeded), and removes it.
    ///
    /// SAFETY: this method is only ever called within a single thread.
@@ -87,7 +89,7 @@
        }
    }

-    /// Consumes this `RetryQueue` and drains all items from the queue into `sink`.
+    /// Consumes this `RetryQueue` and drains all tasks from the queue into `sink`.
    pub fn drain_into(self, sink: &mut Vec<Task<W>>) {
        let mut queue = self.inner.into_inner();
        sink.reserve(queue.len());
@@ -97,17 +99,15 @@
unsafe impl<W: Worker> Sync for RetryQueue<W> {}

-type DelayedTask<W> = Delayed<Task<W>>;
-
-#[derive(Debug)]
-struct Delayed<T> {
-    value: T,
+/// Wrapper for a `Task` with an associated `Instant` at which it will be available.
+struct DelayedTask<W: Worker> {
+    value: Task<W>,
    until: Instant,
}

-impl<T> Delayed<T> {
-    pub fn new(value: T, delay: Duration) -> Self {
-        Delayed {
+impl<W: Worker> DelayedTask<W> {
+    pub fn new(value: Task<W>, delay: Duration) -> Self {
+        Self {
            value,
            until: Instant::now() + delay,
        }
@@ -119,25 +119,25 @@
///
/// Earlier entries have higher priority (should be popped first), so they are greater than later
/// entries.
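Two details in this file are worth illustrating. First, the delay computation above implements capped exponential backoff: attempt `n` waits `delay_factor * 2^(n-1)` nanoseconds, with the multiplication saturating at `u64::MAX` instead of overflowing. A standalone sketch (simplified: the real code falls back to a zero delay if the exponent itself overflows):

```rust
use std::time::Duration;

// `attempt` is 1-based, as `Task::attempt` is above
fn retry_delay(delay_factor: u64, attempt: u32) -> Duration {
    2u64.checked_pow(attempt.saturating_sub(1))
        .and_then(|multiplier| delay_factor.checked_mul(multiplier))
        .map(Duration::from_nanos)
        // simplified: saturate on any overflow
        .unwrap_or(Duration::from_nanos(u64::MAX))
}

fn main() {
    // with a 1 ms factor the delays double: 1 ms, 2 ms, 4 ms, ...
    assert_eq!(retry_delay(1_000_000, 1), Duration::from_millis(1));
    assert_eq!(retry_delay(1_000_000, 4), Duration::from_millis(8));
}
```

Second, the reversed comparison below is what turns `std`'s max-heap `BinaryHeap` into a min-heap on `until`, so the earliest-available task pops first. The same effect, shown with `std::cmp::Reverse` instead of the manual `Ord` impl:

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::time::{Duration, Instant};

fn main() {
    let now = Instant::now();
    let mut heap = BinaryHeap::new();
    heap.push(Reverse((now + Duration::from_secs(3), "later")));
    heap.push(Reverse((now + Duration::from_secs(1), "sooner")));
    // reversing the ordering makes the earliest deadline the "greatest" entry
    let Reverse((_, first)) = heap.pop().unwrap();
    assert_eq!(first, "sooner");
}
```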
-impl<T> Ord for Delayed<T> {
-    fn cmp(&self, other: &Delayed<T>) -> Ordering {
+impl<W: Worker> Ord for DelayedTask<W> {
+    fn cmp(&self, other: &DelayedTask<W>) -> Ordering {
        other.until.cmp(&self.until)
    }
}

-impl<T> PartialOrd for Delayed<T> {
-    fn partial_cmp(&self, other: &Delayed<T>) -> Option<Ordering> {
+impl<W: Worker> PartialOrd for DelayedTask<W> {
+    fn partial_cmp(&self, other: &DelayedTask<W>) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

-impl<T> PartialEq for Delayed<T> {
-    fn eq(&self, other: &Delayed<T>) -> bool {
+impl<W: Worker> PartialEq for DelayedTask<W> {
+    fn eq(&self, other: &DelayedTask<W>) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

-impl<T> Eq for Delayed<T> {}
+impl<W: Worker> Eq for DelayedTask<W> {}

#[cfg(test)]
mod tests {
diff --git a/src/hive/inner/queue/closed.rs b/src/hive/inner/queue/status.rs
similarity index 88%
rename from src/hive/inner/queue/closed.rs
rename to src/hive/inner/queue/status.rs
index a0e447a..a59d550 100644
--- a/src/hive/inner/queue/closed.rs
+++ b/src/hive/inner/queue/status.rs
@@ -4,9 +4,9 @@ const OPEN: u8 = 0;
const CLOSED_PUSH: u8 = 1;
const CLOSED_POP: u8 = 2;

-pub struct Closed(AtomicU8);
+pub struct Status(AtomicU8);

-impl Closed {
+impl Status {
    pub fn is_closed(&self) -> bool {
        self.0.get() > OPEN
    }
@@ -24,7 +24,7 @@
    }
}

-impl Default for Closed {
+impl Default for Status {
    fn default() -> Self {
        Self(AtomicU8::new(OPEN))
    }
diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs
index 48acb9d..dd3e01a 100644
--- a/src/hive/inner/queue/workstealing.rs
+++ b/src/hive/inner/queue/workstealing.rs
@@ -4,7 +4,7 @@
//! tries to steal a task from the global queue and falls back to stealing from another worker
//! thread. If the `batching` feature is enabled, a worker thread will try to fill its local queue
//! up to the limit when stealing from the global queue.
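The steal order sketched in this header comment is the canonical `crossbeam-deque` pattern that the module builds on. For reference, this is essentially the example from crossbeam's own documentation, simplified relative to this crate (whose workers also consult their abandoned and retry queues and the closed flag first):

```rust
use crossbeam_deque::{Injector, Stealer, Worker};
use std::iter;

fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {
    // pop a task from the local queue, if not empty
    local.pop().or_else(|| {
        // otherwise, look for a task elsewhere
        iter::repeat_with(|| {
            // try stealing a batch of tasks from the global queue,
            // keeping one for ourselves and queuing the rest locally
            global
                .steal_batch_and_pop(local)
                // or try stealing a task from one of the other threads
                .or_else(|| stealers.iter().map(|s| s.steal()).collect())
        })
        // loop while no task was stolen and any steal operation needs to be retried
        .find(|s| !s.is_retry())
        // extract the stolen task, if there is one
        .and_then(|s| s.success())
    })
}
```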
-use super::{Closed, Config, PopTaskError, Task, TaskQueues, Token, WorkerQueues}; +use super::{Config, PopTaskError, Status, Task, TaskQueues, Token, WorkerQueues}; #[cfg(feature = "batching")] use crate::atomic::Atomic; use crate::bee::Worker; @@ -63,10 +63,12 @@ impl TaskQueues for WorkstealingTaskQueues { panic!("close must be called before drain"); } let mut tasks = Vec::new(); - let global = crate::hive::unwrap_arc(self.global); + let global = crate::hive::unwrap_arc(self.global) + .unwrap_or_else(|_| panic!("timeout waiting to take ownership of global queue")); global.drain_into(&mut tasks); for local in self.local.into_inner().into_iter() { - let local = crate::hive::unwrap_arc(local); + let local = crate::hive::unwrap_arc(local) + .unwrap_or_else(|_| panic!("timeout waiting to take ownership of local queue")); local.drain_into(&mut tasks); } tasks @@ -76,7 +78,7 @@ impl TaskQueues for WorkstealingTaskQueues { pub struct GlobalQueue { queue: Injector>, stealers: RwLock>>>, - closed: Closed, + status: Status, } impl GlobalQueue { @@ -84,7 +86,7 @@ impl GlobalQueue { Self { queue: Injector::new(), stealers: Default::default(), - closed: Default::default(), + status: Default::default(), } } @@ -93,7 +95,7 @@ impl GlobalQueue { } fn try_push(&self, task: Task) -> Result<(), Task> { - if !self.closed.can_push() { + if !self.status.can_push() { return Err(task); } self.queue.push(task); @@ -141,11 +143,11 @@ impl GlobalQueue { } fn is_closed(&self) -> bool { - self.closed.is_closed() + self.status.is_closed() } fn close(&self, urgent: bool) { - self.closed.set(urgent); + self.status.set(urgent); } fn drain_into(self, tasks: &mut Vec>) { @@ -242,7 +244,7 @@ impl LocalQueueShared { global: &GlobalQueue, local_batch: &crossbeam_deque::Worker>, ) -> Result, PopTaskError> { - if !global.closed.can_pop() { + if !global.status.can_pop() { return Err(PopTaskError::Closed); } // first try to get a previously abandoned task diff --git a/src/hive/mod.rs b/src/hive/mod.rs index d9d4550..e4cacfd 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -397,17 +397,20 @@ pub mod prelude { }; } -fn unwrap_arc(mut arc: std::sync::Arc) -> T { +/// Utility function to loop (with exponential backoff) waiting for other references to `arc` to +/// drop so it can be unwrapped into its inner value. +/// +/// If `arc` cannot be unwrapped with a certain number of loops (with an exponentially increasing +/// amount of time between each iteration), `arc` is returned as an error. 
+fn unwrap_arc(mut arc: std::sync::Arc) -> Result> { + const MAX_LOOPS: usize = 100; // wait for worker threads to drop, then take ownership of the shared data and convert it // into a Husk let mut backoff = None::; - loop { - // TODO: may want to have some timeout or other kind of limit to prevent this from - // looping forever if a worker thread somehow gets stuck, or if the `num_referrers` - // counter is corrupted + for _ in 0..MAX_LOOPS { arc = match std::sync::Arc::try_unwrap(arc) { Ok(inner) => { - return inner; + return Ok(inner); } Err(arc) => { backoff @@ -417,6 +420,7 @@ fn unwrap_arc(mut arc: std::sync::Arc) -> T { } }; } + Err(arc) } #[cfg(test)] From 13f6ce27b8708908c79b02136ce474e54ae6c2a1 Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 11:41:43 -0800 Subject: [PATCH 21/67] fix --- src/hive/hive.rs | 5 ++- src/hive/inner/queue/channel.rs | 4 +- src/hive/inner/queue/workstealing.rs | 4 +- src/hive/mod.rs | 56 ++++++++++++++++------------ 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 8622127..88f9397 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -550,7 +550,10 @@ impl, T: TaskQueues> Hive { // wait for all tasks to finish shared.wait_on_done(); // unwrap the Arc and return the inner Shared value - Some(super::unwrap_arc(shared).expect("timeout waiting to take ownership of shared data")) + Some( + super::util::unwrap_arc(shared) + .expect("timeout waiting to take ownership of shared data"), + ) } /// Consumes this `Hive` and attempts to shut it down gracefully. diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index 1bdf321..bd354f5 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -65,11 +65,11 @@ impl TaskQueues for ChannelTaskQueues { panic!("close must be called before drain"); } let mut tasks = Vec::new(); - let global = crate::hive::unwrap_arc(self.global) + let global = crate::hive::util::unwrap_arc(self.global) .unwrap_or_else(|_| panic!("timeout waiting to take ownership of global queue")); global.drain_into(&mut tasks); for local in self.local.into_inner().into_iter() { - let local = crate::hive::unwrap_arc(local) + let local = crate::hive::util::unwrap_arc(local) .unwrap_or_else(|_| panic!("timeout waiting to take ownership of local queue")); local.drain_into(&mut tasks); } diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index dd3e01a..a05c77c 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -63,11 +63,11 @@ impl TaskQueues for WorkstealingTaskQueues { panic!("close must be called before drain"); } let mut tasks = Vec::new(); - let global = crate::hive::unwrap_arc(self.global) + let global = crate::hive::util::unwrap_arc(self.global) .unwrap_or_else(|_| panic!("timeout waiting to take ownership of global queue")); global.drain_into(&mut tasks); for local in self.local.into_inner().into_iter() { - let local = crate::hive::unwrap_arc(local) + let local = crate::hive::util::unwrap_arc(local) .unwrap_or_else(|_| panic!("timeout waiting to take ownership of local queue")); local.drain_into(&mut tasks); } diff --git a/src/hive/mod.rs b/src/hive/mod.rs index e4cacfd..840ec95 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -397,30 +397,40 @@ pub mod prelude { }; } -/// Utility function to loop (with exponential backoff) waiting for other references to `arc` to -/// drop so it can be unwrapped into its inner value. 
-/// -/// If `arc` cannot be unwrapped with a certain number of loops (with an exponentially increasing -/// amount of time between each iteration), `arc` is returned as an error. -fn unwrap_arc(mut arc: std::sync::Arc) -> Result> { - const MAX_LOOPS: usize = 100; - // wait for worker threads to drop, then take ownership of the shared data and convert it - // into a Husk - let mut backoff = None::; - for _ in 0..MAX_LOOPS { - arc = match std::sync::Arc::try_unwrap(arc) { - Ok(inner) => { - return Ok(inner); - } - Err(arc) => { - backoff - .get_or_insert_with(crossbeam_utils::Backoff::new) - .spin(); - arc - } - }; +mod util { + use crossbeam_utils::Backoff; + use std::sync::Arc; + use std::time::{Duration, Instant}; + + const MAX_WAIT: Duration = Duration::from_secs(10); + + /// Utility function to loop (with exponential backoff) waiting for other references to `arc` to + /// drop so it can be unwrapped into its inner value. + /// + /// If `arc` cannot be unwrapped with a certain amount of time (with an exponentially + /// increasing gap between each iteration), `arc` is returned as an error. + pub fn unwrap_arc(mut arc: Arc) -> Result> { + // wait for worker threads to drop, then take ownership of the shared data and convert it + // into a Husk + let mut backoff = None::; + let mut start = None::; + loop { + arc = match std::sync::Arc::try_unwrap(arc) { + Ok(inner) => { + return Ok(inner); + } + Err(arc) if start.is_none() => { + let _ = start.insert(Instant::now()); + arc + } + Err(arc) if Instant::now() - start.unwrap() > MAX_WAIT => return Err(arc), + Err(arc) => { + backoff.get_or_insert_with(Backoff::new).spin(); + arc + } + }; + } } - Err(arc) } #[cfg(test)] From 3ac8020787ae506d02fc06a3702650c7fa36bebb Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 12:17:17 -0800 Subject: [PATCH 22/67] fix import order --- src/hive/inner/queue/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs index 902e25f..8980b5f 100644 --- a/src/hive/inner/queue/mod.rs +++ b/src/hive/inner/queue/mod.rs @@ -1,15 +1,15 @@ mod channel; -mod status; #[cfg(feature = "retry")] mod retry; +mod status; mod workstealing; pub use self::channel::ChannelTaskQueues; pub use self::workstealing::WorkstealingTaskQueues; -use self::status::Status; #[cfg(feature = "retry")] use self::retry::RetryQueue; +use self::status::Status; use super::{Config, Task, Token}; use crate::bee::Worker; From 022fb1bf96d0fd38172c14bd7e6d8a8364a32875 Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 12:32:28 -0800 Subject: [PATCH 23/67] update docs --- README.md | 3 +-- src/hive/mod.rs | 25 ++++++++++++++++++++++++- src/lib.rs | 9 --------- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 68266e4..ea26646 100644 --- a/README.md +++ b/README.md @@ -108,8 +108,7 @@ There are multiple methods in each group that differ by how the task results (ca * The methods with the `_unordered` suffix instead return an unordered iterator, which may be more performant than the ordered iterator * The methods with the `_send` suffix accept a channel `Sender` and send the `Outcome`s to that - channel as they are completed. - * Note that, for these methods, the `tx` parameter is of type `Borrow>`, which allows you to pass in either a value or a reference. Passing a value causes the `Sender` to be dropped after the call, while passing a reference allows you to use the same `Sender` for multiple `_send` calls. 
Note that in the later case, you need to explicitly drop the sender (e.g., `drop(tx)`), pass it by value to the last `_send` call, or be careful about how you obtain outcomes from the `Receiver` as methods such as `recv` and `iter` will block until the `Sender` is dropped. You should *not* pass clones of the `Sender` to `_send` methods as this results in slightly worse performance and still has the requirement that you manually drop the original `Sender` value.
+  channel as they are completed (see this [note](https://docs.rs/beekeeper/latest/beekeeper/hive/index.html#outcome-channels)).
* The methods with the `_store` suffix store the `Outcome`s in the `Hive`; these may be
  retrieved later using the
  [`Hive::take_stored()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.take_stored)
  method, using one of the `remove*` methods (which requires
diff --git a/src/hive/mod.rs b/src/hive/mod.rs
index 840ec95..c20c1ff 100644
--- a/src/hive/mod.rs
+++ b/src/hive/mod.rs
@@ -183,7 +183,8 @@
//!
//! Each group of functions has multiple variants:
//! * The methods that end with `_send` all take a channel sender as a second argument and will
-//!   deliver results to that channel as they become available.
+//!   deliver results to that channel as they become available. See the note below on proper use of
+//!   channels.
//! * The methods that end with `_store` are all non-blocking functions that return the task IDs
//!   associated with the submitted tasks and will store the task results in the hive. The outcomes
//!   can be retrieved from the hive later by their IDs, e.g., using `remove_success`.
@@ -212,6 +213,28 @@
//! You can create an instance of the enabled outcome channel type using the [`outcome_channel`]
//! function.
//!
+//! `Hive` has several methods (with the `_send` suffix) for submitting tasks whose outcomes will be
+//! delivered to a user-specified channel. Note that, for these methods, the `tx` parameter is of
+//! type `Borrow<OutcomeSender<W>>`, which allows you to pass in either a value or a reference.
+//! Passing a value causes the `Sender` to be dropped after the call; passing a reference allows
+//! you to use the same `Sender` for multiple `_send` calls, but you need to explicitly drop the
+//! sender (e.g., `drop(tx)`), pass it by value to the last `_send` call, or be careful about how
+//! you obtain outcomes from the `Receiver`. Methods such as `recv` and `iter` will block until the
+//! `Sender` is dropped. Since `Receiver` implements `Iterator`, you can use the methods of
+//! [`OutcomeIteratorExt`] to iterate over the outcomes for specific task IDs.
+//!
+//! ```no_run
+//! use beekeeper::hive::prelude::*;
+//! let (tx, rx) = outcome_channel();
+//! let hive = ...
+//! let task_ids = hive.map_send(0..10, tx);
+//! rx.select_unordered_outputs(task_ids).for_each(|output| ...);
+//! ```
+//!
+//! You should *not* pass clones of the `Sender` to `_send` methods as this results in slightly
+//! worse performance and still has the requirement that you manually drop the original `Sender`
+//! value.
+//!
//! # Retrieving outcomes
//!
//! Each task that is successfully submitted to a `Hive` will have a corresponding `Outcome`.
diff --git a/src/lib.rs b/src/lib.rs
index 4faf4c6..63d64c7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -96,15 +96,6 @@
//!   more performant than the ordered iterator
//! * The methods with the `_send` suffix accept a channel [`Sender`](crate::channel::Sender) and
//!   send the `Outcome`s to that channel as they are completed
-//!
* Note that, for these methods, the `tx` parameter is of type `Borrow>`, which -//! allows you to pass in either a value or a reference. Passing a value causes the `Sender` -//! to be dropped after the call, while passing a reference allows you to use the same -//! `Sender` for multiple `_send` calls. Note that in the later case, you need to explicitly -//! drop the sender (e.g., `drop(tx)`), pass it by value to the last `_send` call, or be -//! careful about how you obtain outcomes from the `Receiver` as methods such as `recv` and -//! `iter` will block until the `Sender` is dropped. You should *not* pass clones of the -//! `Sender` to `_send` methods as this results in slightly worse performance and still has -//! the requirement that you manually drop the original `Sender` value. //! * The methods with the `_store` suffix store the `Outcome`s in the `Hive`; these may be //! retrieved later using the [`Hive::take_stored()`](crate::hive::Hive::take_stored) method, //! using one of the `remove*` methods (which requires From a2313b5d96b1b2dbd4d7f0cea6968218e0d96594 Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 13:35:55 -0800 Subject: [PATCH 24/67] refactor queue builders into trait, provide builder functions --- benches/perf.rs | 2 +- src/hive/builder/bee.rs | 2 +- src/hive/builder/channel.rs | 71 ------------------------------- src/hive/builder/full.rs | 2 +- src/hive/builder/mod.rs | 32 ++++++++++++-- src/hive/builder/open.rs | 2 +- src/hive/builder/workstealing.rs | 73 -------------------------------- src/hive/hive.rs | 7 ++- src/hive/husk.rs | 1 + src/hive/inner/builder.rs | 33 ++++++++------- src/hive/mod.rs | 15 ++++--- src/util.rs | 4 +- 12 files changed, 67 insertions(+), 177 deletions(-) delete mode 100644 src/hive/builder/channel.rs delete mode 100644 src/hive/builder/workstealing.rs diff --git a/benches/perf.rs b/benches/perf.rs index 8bce664..3dbda77 100644 --- a/benches/perf.rs +++ b/benches/perf.rs @@ -1,5 +1,5 @@ use beekeeper::bee::stock::EchoWorker; -use beekeeper::hive::{outcome_channel, Builder, ChannelBuilder}; +use beekeeper::hive::{outcome_channel, Builder, ChannelBuilder, TaskQueuesBuilder}; use divan::{bench, black_box_drop, AllocProfiler, Bencher}; use itertools::iproduct; diff --git a/src/hive/builder/bee.rs b/src/hive/builder/bee.rs index 2cb2083..0a20401 100644 --- a/src/hive/builder/bee.rs +++ b/src/hive/builder/bee.rs @@ -122,7 +122,7 @@ impl BeeBuilder> { } impl BuilderConfig for BeeBuilder { - fn config(&mut self, _: Token) -> &mut Config { + fn config_ref(&mut self, _: Token) -> &mut Config { &mut self.config } } diff --git a/src/hive/builder/channel.rs b/src/hive/builder/channel.rs deleted file mode 100644 index 1525d5f..0000000 --- a/src/hive/builder/channel.rs +++ /dev/null @@ -1,71 +0,0 @@ -use super::{BuilderConfig, FullBuilder, Token}; -use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; -use crate::hive::{ChannelTaskQueues, Config}; - -#[derive(Clone, Default)] -pub struct ChannelBuilder(Config); - -impl ChannelBuilder { - /// Creates a new `ChannelBuilder` with the given queen and no options configured. - pub fn empty() -> Self { - Self(Config::empty()) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to - /// create [`Worker`]s. 
- pub fn with_queen(self, queen: I) -> FullBuilder> - where - Q: Queen, - I: Into, - { - FullBuilder::from(self.0, queen.into()) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`Queen`] created with - /// [`Q::default()`](std::default::Default) to create [`Worker`]s. - pub fn with_queen_default(self) -> FullBuilder> - where - Q: Queen + Default, - { - FullBuilder::from(self.0, Q::default()) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`QueenMut`] created with - /// [`Q::default()`](std::default::Default) to create [`Worker`]s. - pub fn with_queen_mut_default(self) -> FullBuilder, ChannelTaskQueues> - where - Q: QueenMut + Default, - { - FullBuilder::from(self.0, QueenCell::new(Q::default())) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created by - /// cloning `worker`. - pub fn with_worker(self, worker: W) -> FullBuilder, ChannelTaskQueues> - where - W: Worker + Send + Sync + Clone, - { - FullBuilder::from(self.0, CloneQueen::new(worker)) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created using - /// [`W::default()`](std::default::Default). - pub fn with_worker_default(self) -> FullBuilder, ChannelTaskQueues> - where - W: Worker + Send + Sync + Default, - { - FullBuilder::from(self.0, DefaultQueen::default()) - } -} - -impl BuilderConfig for ChannelBuilder { - fn config(&mut self, _: Token) -> &mut Config { - &mut self.0 - } -} - -impl From for ChannelBuilder { - fn from(value: Config) -> Self { - Self(value) - } -} diff --git a/src/hive/builder/full.rs b/src/hive/builder/full.rs index 902e878..138399c 100644 --- a/src/hive/builder/full.rs +++ b/src/hive/builder/full.rs @@ -42,7 +42,7 @@ impl> FullBuilder { } impl> BuilderConfig for FullBuilder { - fn config(&mut self, _: Token) -> &mut Config { + fn config_ref(&mut self, _: Token) -> &mut Config { &mut self.config } } diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs index b707dfc..e211ed8 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -40,16 +40,40 @@ //! only the first `num_threads` indices will be used. //! 
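For orientation, here is a usage sketch of the builder entry points this commit introduces: the free functions below plus the `with_worker_default` method that moves onto the `TaskQueuesBuilder` trait in the following commit. It assumes the `channel_builder` re-export and the updated prelude added to `hive/mod.rs` later in this same series:

```rust
use beekeeper::bee::stock::{Thunk, ThunkWorker};
use beekeeper::hive::prelude::*;

fn main() {
    // `channel_builder(false)` starts from an empty config, like the old
    // `ChannelBuilder::empty()`; passing `true` starts from the global defaults
    let hive = channel_builder(false)
        .num_threads(4)
        .with_worker_default::<ThunkWorker<i32>>()
        .build();
    // submit one task and block until its outcome is available
    let outcome = hive.apply(Thunk::of(|| 6 * 7));
    assert!(matches!(outcome, Outcome::Success { .. }));
}
```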
mod bee; -mod channel; mod full; mod open; -mod workstealing; +mod queue; pub use bee::BeeBuilder; -pub use channel::ChannelBuilder; pub use full::FullBuilder; pub use open::OpenBuilder; -pub use workstealing::WorkstealingBuilder; +pub use queue::channel::ChannelBuilder; +pub use queue::workstealing::WorkstealingBuilder; +pub use queue::TaskQueuesBuilder; + +pub fn open(with_defaults: bool) -> OpenBuilder { + if with_defaults { + OpenBuilder::default() + } else { + OpenBuilder::empty() + } +} + +pub fn channel(with_defaults: bool) -> ChannelBuilder { + if with_defaults { + ChannelBuilder::default() + } else { + ChannelBuilder::empty() + } +} + +pub fn workstealing(with_defaults: bool) -> WorkstealingBuilder { + if with_defaults { + WorkstealingBuilder::default() + } else { + WorkstealingBuilder::empty() + } +} use crate::hive::inner::{BuilderConfig, Token}; diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs index fd539b1..9985fe4 100644 --- a/src/hive/builder/open.rs +++ b/src/hive/builder/open.rs @@ -254,7 +254,7 @@ impl OpenBuilder { } impl BuilderConfig for OpenBuilder { - fn config(&mut self, _: Token) -> &mut Config { + fn config_ref(&mut self, _: Token) -> &mut Config { &mut self.0 } } diff --git a/src/hive/builder/workstealing.rs b/src/hive/builder/workstealing.rs deleted file mode 100644 index e8f7006..0000000 --- a/src/hive/builder/workstealing.rs +++ /dev/null @@ -1,73 +0,0 @@ -use super::{BuilderConfig, FullBuilder, Token}; -use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; -use crate::hive::{Config, WorkstealingTaskQueues}; - -#[derive(Clone, Default)] -pub struct WorkstealingBuilder(Config); - -impl WorkstealingBuilder { - /// Creates a new `WorkstealingBuilder` with the given queen and no options configured. - pub fn empty() -> Self { - Self(Config::empty()) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to - /// create [`Worker`]s. - pub fn with_queen(self, queen: I) -> FullBuilder> - where - Q: Queen, - I: Into, - { - FullBuilder::from(self.0, queen.into()) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`Queen`] created with - /// [`Q::default()`](std::default::Default) to create [`Worker`]s. - pub fn with_queen_default(self) -> FullBuilder> - where - Q: Queen + Default, - { - FullBuilder::from(self.0, Q::default()) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`QueenMut`] created with - /// [`Q::default()`](std::default::Default) to create [`Worker`]s. - pub fn with_queen_mut_default( - self, - ) -> FullBuilder, WorkstealingTaskQueues> - where - Q: QueenMut + Default, - { - FullBuilder::from(self.0, QueenCell::new(Q::default())) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created by - /// cloning `worker`. - pub fn with_worker(self, worker: W) -> FullBuilder, WorkstealingTaskQueues> - where - W: Worker + Send + Sync + Clone, - { - FullBuilder::from(self.0, CloneQueen::new(worker)) - } - - /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created using - /// [`W::default()`](std::default::Default). 
- pub fn with_worker_default(self) -> FullBuilder, WorkstealingTaskQueues> - where - W: Worker + Send + Sync + Default, - { - FullBuilder::from(self.0, DefaultQueen::default()) - } -} - -impl BuilderConfig for WorkstealingBuilder { - fn config(&mut self, _: Token) -> &mut Config { - &mut self.0 - } -} - -impl From for WorkstealingBuilder { - fn from(value: Config) -> Self { - Self(value) - } -} diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 88f9397..cdf9ec0 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -1,6 +1,7 @@ use super::{ ChannelBuilder, ChannelTaskQueues, Config, DerefOutcomes, Husk, Outcome, OutcomeBatch, - OutcomeIteratorExt, OutcomeSender, Shared, SpawnError, TaskQueues, WorkerQueues, + OutcomeIteratorExt, OutcomeSender, Shared, SpawnError, TaskQueues, TaskQueuesBuilder, + WorkerQueues, }; use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; use std::borrow::Borrow; @@ -1003,7 +1004,9 @@ where mod tests { use super::Poisoned; use crate::bee::stock::{Caller, Thunk, ThunkWorker}; - use crate::hive::{outcome_channel, Builder, ChannelBuilder, Outcome, OutcomeIteratorExt}; + use crate::hive::{ + outcome_channel, Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder, + }; use std::collections::HashMap; use std::thread; use std::time::Duration; diff --git a/src/hive/husk.rs b/src/hive/husk.rs index 95b18cd..6051e67 100644 --- a/src/hive/husk.rs +++ b/src/hive/husk.rs @@ -140,6 +140,7 @@ mod tests { use crate::hive::ChannelTaskQueues; use crate::hive::{ outcome_channel, Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, OutcomeStore, + TaskQueuesBuilder, }; #[test] diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index c10e102..6c7959f 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -3,7 +3,7 @@ use super::{Config, Token}; /// Private (sealed) trait depended on by `Builder` that must be implemented by builder types. pub trait BuilderConfig { /// Returns a reference to the underlying `Config`. - fn config(&mut self, token: Token) -> &mut Config; + fn config_ref(&mut self, token: Token) -> &mut Config; } /// Trait that provides `Builder` types with methods for setting configuration parameters. @@ -36,14 +36,14 @@ pub trait Builder: BuilderConfig + Sized { /// # } /// ``` fn num_threads(mut self, num: usize) -> Self { - let _ = self.config(Token).num_threads.set(Some(num)); + let _ = self.config_ref(Token).num_threads.set(Some(num)); self } /// Sets the number of worker threads to the global default value. 
fn with_default_num_threads(mut self) -> Self { let _ = self - .config(Token) + .config_ref(Token) .num_threads .set(super::config::DEFAULTS.lock().num_threads.get()); self @@ -73,7 +73,10 @@ pub trait Builder: BuilderConfig + Sized { /// # } /// ``` fn with_thread_per_core(mut self) -> Self { - let _ = self.config(Token).num_threads.set(Some(num_cpus::get())); + let _ = self + .config_ref(Token) + .num_threads + .set(Some(num_cpus::get())); self } @@ -104,7 +107,7 @@ pub trait Builder: BuilderConfig + Sized { /// # } /// ``` fn thread_name>(mut self, name: T) -> Self { - let _ = self.config(Token).thread_name.set(Some(name.into())); + let _ = self.config_ref(Token).thread_name.set(Some(name.into())); self } @@ -137,7 +140,7 @@ pub trait Builder: BuilderConfig + Sized { /// # } /// ``` fn thread_stack_size(mut self, size: usize) -> Self { - let _ = self.config(Token).thread_stack_size.set(Some(size)); + let _ = self.config_ref(Token).thread_stack_size.set(Some(size)); self } @@ -203,9 +206,9 @@ pub trait Builder: BuilderConfig + Sized { #[cfg(feature = "batching")] fn batch_limit(mut self, batch_limit: usize) -> Self { if batch_limit == 0 { - self.config(Token).batch_limit.set(None); + self.config_ref(Token).batch_limit.set(None); } else { - self.config(Token).batch_limit.set(Some(batch_limit)); + self.config_ref(Token).batch_limit.set(Some(batch_limit)); } self } @@ -214,7 +217,7 @@ pub trait Builder: BuilderConfig + Sized { #[cfg(feature = "batching")] fn with_default_batch_limit(mut self) -> Self { let _ = self - .config(Token) + .config_ref(Token) .batch_limit .set(super::config::DEFAULTS.lock().batch_limit.get()); self @@ -264,9 +267,9 @@ pub trait Builder: BuilderConfig + Sized { #[cfg(feature = "retry")] fn max_retries(mut self, limit: u32) -> Self { let _ = if limit == 0 { - self.config(Token).max_retries.set(None) + self.config_ref(Token).max_retries.set(None) } else { - self.config(Token).max_retries.set(Some(limit)) + self.config_ref(Token).max_retries.set(Some(limit)) }; self } @@ -311,9 +314,9 @@ pub trait Builder: BuilderConfig + Sized { #[cfg(feature = "retry")] fn retry_factor(mut self, duration: std::time::Duration) -> Self { if duration == std::time::Duration::ZERO { - let _ = self.config(Token).retry_factor.set(None); + let _ = self.config_ref(Token).retry_factor.set(None); } else { - let _ = self.config(Token).set_retry_factor_from(duration); + let _ = self.config_ref(Token).set_retry_factor_from(duration); }; self } @@ -323,11 +326,11 @@ pub trait Builder: BuilderConfig + Sized { fn with_default_retries(mut self) -> Self { let defaults = super::config::DEFAULTS.lock(); let _ = self - .config(Token) + .config_ref(Token) .max_retries .set(defaults.max_retries.get()); let _ = self - .config(Token) + .config_ref(Token) .retry_factor .set(defaults.retry_factor.get()); self diff --git a/src/hive/mod.rs b/src/hive/mod.rs index c20c1ff..7923255 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -389,7 +389,10 @@ mod husk; mod inner; mod outcome; -pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder}; +pub use self::builder::{ + channel as channel_builder, open as open_builder, workstealing as workstealing_builder, +}; +pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder, TaskQueuesBuilder}; pub use self::hive::{DefaultHive, Hive, Poisoned}; pub use self::husk::Husk; pub use self::inner::{set_config::*, Builder, ChannelTaskQueues, WorkstealingTaskQueues}; @@ -415,8 +418,8 @@ pub fn outcome_channel() -> (OutcomeSender, 
OutcomeReceiver) { pub mod prelude { pub use super::{ - outcome_channel, Builder, ChannelBuilder, ChannelTaskQueues, Hive, Husk, OpenBuilder, - Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned, + channel_builder, open_builder, outcome_channel, workstealing_builder, Builder, Hive, Husk, + Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned, TaskQueuesBuilder, }; } @@ -460,7 +463,7 @@ mod util { mod tests { use super::{ Builder, ChannelBuilder, ChannelTaskQueues, Hive, OpenBuilder, Outcome, OutcomeIteratorExt, - OutcomeStore, + OutcomeStore, TaskQueuesBuilder, }; use crate::barrier::IndexedBarrier; use crate::bee::stock::{Caller, OnceCaller, RefCaller, Thunk, ThunkWorker}; @@ -1779,7 +1782,7 @@ mod batching_tests { use crate::bee::DefaultQueen; use crate::hive::{ Builder, ChannelBuilder, ChannelTaskQueues, Hive, OutcomeIteratorExt, OutcomeReceiver, - OutcomeSender, + OutcomeSender, TaskQueuesBuilder, }; use std::collections::HashMap; use std::thread::{self, ThreadId}; @@ -1919,7 +1922,7 @@ mod batching_tests { mod retry_tests { use crate::bee::stock::RetryCaller; use crate::bee::{ApplyError, Context}; - use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeIteratorExt}; + use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder}; use std::time::{Duration, SystemTime}; fn echo_time(i: usize, ctx: &Context) -> Result> { diff --git a/src/util.rs b/src/util.rs index 86916c4..20d057c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,7 +4,7 @@ //! creating the [`Hive`](crate::hive::Hive), submitting tasks, collecting results, and shutting //! down the `Hive` properly. use crate::bee::stock::{Caller, OnceCaller}; -use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeBatch}; +use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeBatch, TaskQueuesBuilder}; use std::fmt::Debug; /// Convenience function that creates a `Hive` with `num_threads` worker threads that execute the @@ -117,7 +117,7 @@ pub use retry::try_map_retryable; mod retry { use crate::bee::stock::RetryCaller; use crate::bee::{ApplyError, Context}; - use crate::hive::{Builder, ChannelBuilder, OutcomeBatch}; + use crate::hive::{Builder, ChannelBuilder, OutcomeBatch, TaskQueuesBuilder}; use std::fmt::Debug; /// Convenience function that creates a `Hive` with `num_threads` worker threads that execute the From a23eeb2ee0324e17a70cdf0c59f8848335b9aebe Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 13:36:13 -0800 Subject: [PATCH 25/67] add queue --- src/hive/builder/queue.rs | 130 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 src/hive/builder/queue.rs diff --git a/src/hive/builder/queue.rs b/src/hive/builder/queue.rs new file mode 100644 index 0000000..00ef159 --- /dev/null +++ b/src/hive/builder/queue.rs @@ -0,0 +1,130 @@ +use super::FullBuilder; +use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; +use crate::hive::{Builder, TaskQueues}; + +pub trait TaskQueuesBuilder: Builder + Default + Sized { + type TaskQueues: TaskQueues; + + fn empty() -> Self; + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to + /// create [`Worker`]s. + fn with_queen>( + self, + queen: I, + ) -> FullBuilder>; + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`Queen`] created with + /// [`Q::default()`](std::default::Default) to create [`Worker`]s. 
+ fn with_queen_default(self) -> FullBuilder> + where + Q: Queen + Default, + { + self.with_queen(Q::default()) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`QueenMut`] created with + /// [`Q::default()`](std::default::Default) to create [`Worker`]s. + fn with_queen_mut_default(self) -> FullBuilder, Self::TaskQueues> + where + Q: QueenMut + Default, + { + self.with_queen(QueenCell::new(Q::default())) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created by + /// cloning `worker`. + fn with_worker(self, worker: W) -> FullBuilder, Self::TaskQueues> + where + W: Worker + Send + Sync + Clone, + { + self.with_queen(CloneQueen::new(worker)) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] with [`Worker`]s created using + /// [`W::default()`](std::default::Default). + fn with_worker_default(self) -> FullBuilder, Self::TaskQueues> + where + W: Worker + Send + Sync + Default, + { + self.with_queen(DefaultQueen::default()) + } +} + +pub mod channel { + use super::*; + use crate::hive::builder::{BuilderConfig, Token}; + use crate::hive::{ChannelTaskQueues, Config}; + + #[derive(Clone, Default)] + pub struct ChannelBuilder(Config); + + impl BuilderConfig for ChannelBuilder { + fn config_ref(&mut self, _: Token) -> &mut Config { + &mut self.0 + } + } + + impl TaskQueuesBuilder for ChannelBuilder { + type TaskQueues = ChannelTaskQueues; + + fn empty() -> Self { + Self(Config::empty()) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to + /// create [`Worker`]s. + fn with_queen(self, queen: I) -> FullBuilder> + where + Q: Queen, + I: Into, + { + FullBuilder::from(self.0, queen.into()) + } + } + + impl From for ChannelBuilder { + fn from(value: Config) -> Self { + Self(value) + } + } +} + +pub mod workstealing { + use super::*; + use crate::hive::builder::{BuilderConfig, Token}; + use crate::hive::{Config, WorkstealingTaskQueues}; + + #[derive(Clone, Default)] + pub struct WorkstealingBuilder(Config); + + impl BuilderConfig for WorkstealingBuilder { + fn config_ref(&mut self, _: Token) -> &mut Config { + &mut self.0 + } + } + + impl TaskQueuesBuilder for WorkstealingBuilder { + type TaskQueues = WorkstealingTaskQueues; + + fn empty() -> Self { + Self(Config::empty()) + } + + /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to + /// create [`Worker`]s. 
+ fn with_queen(self, queen: I) -> FullBuilder> + where + Q: Queen, + I: Into, + { + FullBuilder::from(self.0, queen.into()) + } + } + + impl From for WorkstealingBuilder { + fn from(value: Config) -> Self { + Self(value) + } + } +} From 8d09eef60c4b6f7ce4b110737f512a867ab31f03 Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 13:58:26 -0800 Subject: [PATCH 26/67] fix --- src/hive/inner/builder.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 6c7959f..866f8bf 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -180,7 +180,7 @@ pub trait Builder: BuilderConfig + Sized { /// ``` #[cfg(feature = "affinity")] fn core_affinity>(mut self, affinity: C) -> Self { - let _ = self.config(Token).affinity.set(Some(affinity.into())); + let _ = self.config_ref(Token).affinity.set(Some(affinity.into())); self } @@ -190,7 +190,7 @@ pub trait Builder: BuilderConfig + Sized { #[cfg(feature = "affinity")] fn with_default_core_affinity(mut self) -> Self { let _ = self - .config(Token) + .config_ref(Token) .affinity .set(Some(crate::hive::cores::Cores::all())); self From 81d1ff3fcbe9fb126962166c06f78f144027776f Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 14:17:26 -0800 Subject: [PATCH 27/67] fix --- Cargo.toml | 2 +- src/hive/mod.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index de54cfd..061ea0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,7 @@ name = "perf" harness = false [features] -default = ["batching", "retry"] +default = ["affinity", "batching", "retry"] affinity = ["dep:core_affinity"] batching = [] retry = [] diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 7923255..317a811 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -1736,11 +1736,11 @@ mod tests { #[cfg(all(test, feature = "affinity"))] mod affinity_tests { use crate::bee::stock::{Thunk, ThunkWorker}; - use crate::hive::{Builder, ChannelBuilder}; + use crate::hive::{Builder, TaskQueuesBuilder}; #[test] fn test_affinity() { - let hive = ChannelBuilder::empty() + let hive = crate::hive::channel_builder(false) .thread_name("affinity example") .num_threads(2) .core_affinity(0..2) @@ -1758,7 +1758,7 @@ mod affinity_tests { #[test] fn test_use_all_cores() { - let hive = ChannelBuilder::empty() + let hive = crate::hive::channel_builder(false) .thread_name("affinity example") .with_thread_per_core() .with_default_core_affinity() From 2a1d47287bad29246d0ae9b24d97da2c9743b89d Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 20 Feb 2025 16:04:48 -0800 Subject: [PATCH 28/67] fix doc tests --- src/hive/builder/open.rs | 12 ++++++------ src/hive/hive.rs | 4 ++-- src/hive/inner/builder.rs | 28 ++++++++++++++-------------- src/hive/mod.rs | 10 +++++----- src/lib.rs | 4 ++-- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs index 9985fe4..fa6981c 100644 --- a/src/hive/builder/open.rs +++ b/src/hive/builder/open.rs @@ -52,7 +52,7 @@ impl OpenBuilder { /// # Examples /// /// ``` - /// # use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// # use beekeeper::hive::prelude::*; /// # use beekeeper::bee::{Context, QueenMut, Worker, WorkerResult}; /// /// #[derive(Debug)] @@ -103,7 +103,7 @@ impl OpenBuilder { /// } /// /// # fn main() { - /// let hive = ChannelBuilder::empty() + /// let hive = channel_builder(false) /// .num_threads(8) /// .thread_stack_size(4_000_000) /// 
.with_queen_mut_default::() @@ -138,7 +138,7 @@ impl OpenBuilder { /// # Examples /// /// ``` - /// # use beekeeper::hive::{Builder, ChannelBuilder, OutcomeIteratorExt}; + /// # use beekeeper::hive::prelude::*; /// # use beekeeper::bee::{Context, Worker, WorkerResult}; /// /// #[derive(Debug, Clone)] @@ -170,7 +170,7 @@ impl OpenBuilder { /// } /// /// # fn main() { - /// let hive = ChannelBuilder::empty() + /// let hive = channel_builder(false) /// .num_threads(8) /// .thread_stack_size(4_000_000) /// .with_worker(MathWorker(5isize)) @@ -196,7 +196,7 @@ impl OpenBuilder { /// # Examples /// /// ``` - /// # use beekeeper::hive::{Builder, ChannelBuilder, OutcomeIteratorExt}; + /// # use beekeeper::hive::prelude::*; /// # use beekeeper::bee::{Context, Worker, WorkerResult}; /// # use std::num::NonZeroIsize; /// @@ -222,7 +222,7 @@ impl OpenBuilder { /// } /// /// # fn main() { - /// let hive = ChannelBuilder::empty() + /// let hive = channel_builder(false) /// .num_threads(8) /// .thread_stack_size(4_000_000) /// .with_worker_default::() diff --git a/src/hive/hive.rs b/src/hive/hive.rs index cdf9ec0..a08c4ea 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -463,12 +463,12 @@ impl, T: TaskQueues> Hive { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, ChannelBuilder}; + /// use beekeeper::hive::prelude::*; /// use std::thread; /// use std::time::Duration; /// /// # fn main() { - /// let hive = ChannelBuilder::empty() + /// let hive = channel_builder(false) /// .num_threads(4) /// .with_worker_default::>() /// .build(); diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 866f8bf..856f304 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -20,10 +20,10 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// use beekeeper::hive::prelude::*; /// /// # fn main() { - /// let hive = ChannelBuilder::empty() + /// let hive = channel_builder(false) /// .num_threads(8) /// .with_worker_default::>() /// .build(); @@ -57,10 +57,10 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// use beekeeper::hive::prelude::*; /// /// # fn main() { - /// let hive = ChannelBuilder::empty() + /// let hive = channel_builder(false) /// .with_thread_per_core() /// .with_worker_default::>() /// .build(); @@ -89,11 +89,11 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// use beekeeper::hive::prelude::*; /// use std::thread; /// /// # fn main() { - /// let hive = ChannelBuilder::default() + /// let hive = channel_builder(true) /// .thread_name("foo") /// .with_worker_default::>() /// .build(); @@ -123,10 +123,10 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// use beekeeper::hive::prelude::*; /// /// # fn main() { - /// let hive = ChannelBuilder::default() + /// let hive = channel_builder(true) /// .thread_stack_size(4_000_000) /// .with_worker_default::>() /// .build(); @@ -161,10 +161,10 @@ pub trait Builder: BuilderConfig + Sized { /// /// ``` /// use beekeeper::bee::stock::{Thunk, ThunkWorker}; - /// use 
beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// use beekeeper::hive::prelude::*; /// /// # fn main() { - /// let hive = ChannelBuilder::empty() + /// let hive = channel_builder(false) /// .num_threads(4) /// .core_affinity(0..4) /// .with_worker_default::>() @@ -237,7 +237,7 @@ pub trait Builder: BuilderConfig + Sized { /// ``` /// use beekeeper::bee::{ApplyError, Context}; /// use beekeeper::bee::stock::RetryCaller; - /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// use beekeeper::hive::prelude::*; /// use std::time; /// /// fn sometimes_fail( @@ -253,7 +253,7 @@ pub trait Builder: BuilderConfig + Sized { /// } /// /// # fn main() { - /// let hive = ChannelBuilder::default() + /// let hive = channel_builder(true) /// .max_retries(3) /// .with_worker(RetryCaller::of(sometimes_fail)) /// .build(); @@ -284,7 +284,7 @@ pub trait Builder: BuilderConfig + Sized { /// ``` /// use beekeeper::bee::{ApplyError, Context}; /// use beekeeper::bee::stock::RetryCaller; - /// use beekeeper::hive::{Builder, ChannelBuilder, Hive}; + /// use beekeeper::hive::prelude::*; /// use std::time; /// /// fn echo_time(i: usize, ctx: &Context) -> Result> { @@ -299,7 +299,7 @@ pub trait Builder: BuilderConfig + Sized { /// } /// /// # fn main() { - /// let hive = ChannelBuilder::default() + /// let hive = channel_builder(true) /// .max_retries(3) /// .retry_factor(time::Duration::from_secs(1)) /// .with_worker(RetryCaller::of(echo_time)) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 317a811..cd18120 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -32,11 +32,11 @@ //! the `Builder`. //! //! ``` -//! use beekeeper::hive::{Builder, ChannelBuilder}; +//! use beekeeper::hive::prelude::*; //! # type MyWorker1 = beekeeper::bee::stock::EchoWorker; //! # type MyWorker2 = beekeeper::bee::stock::EchoWorker; //! -//! let builder1 = ChannelBuilder::default(); +//! let builder1 = channel_builder(true); //! let builder2 = builder1.clone(); //! //! let hive1 = builder1.with_worker_default::().build(); @@ -79,10 +79,10 @@ //! started with no core affinity. //! //! ``` -//! use beekeeper::hive::{Builder, ChannelBuilder}; +//! use beekeeper::hive::prelude::*; //! # type MyWorker = beekeeper::bee::stock::EchoWorker; //! -//! let hive = ChannelBuilder::empty() +//! let hive = channel_builder(false) //! .num_threads(4) //! // 16 cores will be available for pinning but only 4 will be used initially //! .core_affinity(0..16) @@ -223,7 +223,7 @@ //! `Sender` is dropped. Since `Receiver` implements `Iterator`, you can use the methods of //! [`OutcomeIteratorExt`] to iterate over the outcomes for specific task IDs. //! -//! ```no_run +//! ```rust,ignore //! use beekeeper::hive::prelude::*; //! let (tx, rx) = outcome_channel(); //! let hive = ... diff --git a/src/lib.rs b/src/lib.rs index 63d64c7..c55e8bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -151,7 +151,7 @@ //! # fn main() { //! // create a hive to process `Thunk`s - no-argument closures with the //! // same return type (`i32`) -//! let hive = ChannelBuilder::empty() +//! let hive = channel_builder(false) //! .num_threads(4) //! .thread_name("thunk_hive") //! .with_worker_default::>() @@ -278,7 +278,7 @@ //! //! # fn main() { //! // build the Hive -//! let hive = ChannelBuilder::empty() +//! let hive = channel_builder(false) //! .num_threads(4) //! .with_queen_mut_default::() //! 
.build(); From 5e970d11ec2ba89312d9dfdfb748a2d6b731c3d5 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 14:01:56 -0800 Subject: [PATCH 29/67] update to rust 2024/1.85 --- Cargo.toml | 7 +- src/hive/builder/queue.rs | 2 +- src/hive/hive.rs | 177 ++++++------ src/hive/mod.rs | 553 ++++++++++++++++++++++++++------------ 4 files changed, 480 insertions(+), 259 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 061ea0e..db4c6a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,8 @@ name = "beekeeper" description = "A full-featured worker pool library for parallelizing tasks" version = "0.3.0" -edition = "2021" -rust-version = "1.83" +edition = "2024" +rust-version = "1.85" authors = ["John Didion "] repository = "https://github.com/jdidion/beekeeper" license = "MIT OR Apache-2.0" @@ -29,9 +29,10 @@ loole = { version = "0.4.0", optional = true } [dev-dependencies] divan = "0.1.17" +generic-tests = "0.1.3" itertools = "0.14.0" serial_test = "3.2.0" -#rstest = "0.22.0" +rstest = "0.22.0" stacker = "0.1.17" [[bench]] diff --git a/src/hive/builder/queue.rs b/src/hive/builder/queue.rs index 00ef159..66428e0 100644 --- a/src/hive/builder/queue.rs +++ b/src/hive/builder/queue.rs @@ -2,7 +2,7 @@ use super::FullBuilder; use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; use crate::hive::{Builder, TaskQueues}; -pub trait TaskQueuesBuilder: Builder + Default + Sized { +pub trait TaskQueuesBuilder: Builder + Clone + Default + Sized { type TaskQueues: TaskQueues; fn empty() -> Self; diff --git a/src/hive/hive.rs b/src/hive/hive.rs index a08c4ea..27e4673 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -108,8 +108,12 @@ impl, T: TaskQueues> Hive { /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of /// the task will be sent to `tx` upon completion. - pub fn apply_send>>(&self, input: W::Input, tx: S) -> TaskId { - self.shared().send_one_global(input, Some(tx.borrow())) + pub fn apply_send(&self, input: W::Input, outcome_tx: X) -> TaskId + where + X: Borrow>, + { + self.shared() + .send_one_global(input, Some(outcome_tx.borrow())) } /// Sends one `input` to the `Hive` for processing and returns its ID immediately. The @@ -123,10 +127,10 @@ impl, T: TaskQueues> Hive { /// /// This method is more efficient than [`map`](Self::map) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm(&self, batch: I) -> impl Iterator> + pub fn swarm(&self, batch: B) -> impl Iterator> + use where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, + B: IntoIterator, + B::IntoIter: ExactSizeIterator, { let (tx, rx) = super::outcome_channel(); let task_ids = self.shared().send_batch_global(batch, Some(&tx)); @@ -141,10 +145,10 @@ impl, T: TaskQueues> Hive { /// instead receive the `Outcome`s in the order they were submitted. This method is more /// efficient than [`map_unordered`](Self::map_unordered) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm_unordered(&self, batch: I) -> impl Iterator> + pub fn swarm_unordered(&self, batch: B) -> impl Iterator> + use where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, + B: IntoIterator, + B::IntoIter: ExactSizeIterator, { let (tx, rx) = super::outcome_channel(); let task_ids = self.shared().send_batch_global(batch, Some(&tx)); @@ -156,11 +160,11 @@ impl, T: TaskQueues> Hive { /// /// This method is more efficient than [`map_send`](Self::map_send) when the input is an /// [`ExactSizeIterator`]. 
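One change worth flagging in the `swarm`/`swarm_unordered` hunks here: the `+ use` trailing the `impl Iterator` return types is Rust 2024's precise-capturing bound (its angle-bracketed capture list is lost in this rendering), which the edition bump in this patch makes necessary. A self-contained illustration of the feature, not beekeeper's actual signature:

```rust
// With a `use<..>` bound the opaque return type captures exactly the listed
// generics; only T is captured here, so the iterator is not tied to the
// lifetime of the `&[T]` borrow.
fn tag_all<T: Clone>(items: &[T], tag: u32) -> impl Iterator<Item = (u32, T)> + use<T> {
    let owned: Vec<T> = items.to_vec(); // end the borrow before returning
    owned.into_iter().map(move |item| (tag, item))
}

fn main() {
    let iter = {
        let chars = vec!['a', 'b'];
        tag_all(&chars, 7) // legal to return: the borrow is not captured
    };
    assert_eq!(iter.collect::<Vec<_>>(), vec![(7, 'a'), (7, 'b')]);
}
```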
- pub fn swarm_send(&self, batch: I, outcome_tx: S) -> Vec + pub fn swarm_send(&self, batch: B, outcome_tx: S) -> Vec where S: Borrow>, - I: IntoIterator, - I::IntoIter: ExactSizeIterator, + B: IntoIterator, + B::IntoIter: ExactSizeIterator, { self.shared() .send_batch_global(batch, Some(outcome_tx.borrow())) @@ -170,10 +174,10 @@ impl, T: TaskQueues> Hive { /// The [`Outcome`]s of the task are retained and available for later retrieval. /// /// This method is more efficient than `map_store` when the input is an [`ExactSizeIterator`]. - pub fn swarm_store(&self, batch: I) -> Vec + pub fn swarm_store(&self, batch: B) -> Vec where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, + B: IntoIterator, + B::IntoIter: ExactSizeIterator, { self.shared().send_batch_global(batch, None) } @@ -182,12 +186,12 @@ impl, T: TaskQueues> Hive { /// iterator over the [`Outcome`]s in the same order as the inputs. /// /// [`swarm`](Self::swarm) should be preferred when `inputs` is an [`ExactSizeIterator`]. - pub fn map( - &self, - inputs: impl IntoIterator, - ) -> impl Iterator> { + pub fn map(&self, batch: B) -> impl Iterator> + use + where + B: IntoIterator, + { let (tx, rx) = super::outcome_channel(); - let task_ids: Vec<_> = inputs + let task_ids: Vec<_> = batch .into_iter() .map(|task| self.apply_send(task, &tx)) .collect(); @@ -200,13 +204,13 @@ impl, T: TaskQueues> Hive { /// /// [`swarm_unordered`](Self::swarm_unordered) should be preferred when `inputs` is an /// [`ExactSizeIterator`]. - pub fn map_unordered( - &self, - inputs: impl IntoIterator, - ) -> impl Iterator> { + pub fn map_unordered(&self, batch: B) -> impl Iterator> + use + where + B: IntoIterator, + { let (tx, rx) = super::outcome_channel(); // `map` is required (rather than `inspect`) because we need owned items - let task_ids: Vec<_> = inputs + let task_ids: Vec<_> = batch .into_iter() .map(|task| self.apply_send(task, &tx)) .collect(); @@ -219,14 +223,14 @@ impl, T: TaskQueues> Hive { /// /// [`swarm_send`](Self::swarm_send) should be preferred when `inputs` is an /// [`ExactSizeIterator`]. - pub fn map_send>>( - &self, - inputs: impl IntoIterator, - tx: S, - ) -> Vec { - inputs + pub fn map_send(&self, batch: B, outcome_tx: X) -> Vec + where + B: IntoIterator, + X: Borrow>, + { + batch .into_iter() - .map(|input| self.apply_send(input, tx.borrow())) + .map(|input| self.apply_send(input, outcome_tx.borrow())) .collect() } @@ -235,8 +239,11 @@ impl, T: TaskQueues> Hive { /// /// [`swarm_store`](Self::swarm_store) should be preferred when `inputs` is an /// [`ExactSizeIterator`]. - pub fn map_store(&self, inputs: impl IntoIterator) -> Vec { - inputs + pub fn map_store(&self, batch: B) -> Vec + where + B: IntoIterator, + { + batch .into_iter() .map(|input| self.apply_store(input)) .collect() @@ -245,17 +252,13 @@ impl, T: TaskQueues> Hive { /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing. /// Returns an [`OutcomeBatch`] of the outputs and the final state value. 
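The reshaped `scan` threads a fold state through input generation: each item updates the state, the state produces the next task input, and the final state comes back alongside the outcomes. A short sketch against the `Caller` stock worker that this series' own tests use:

```rust
use beekeeper::bee::stock::Caller;
use beekeeper::hive::{Builder, TaskQueuesBuilder, channel_builder};

fn main() {
    let hive = channel_builder(false)
        .with_worker(Caller::of(|i: i32| i * i))
        .num_threads(4)
        .build();

    // Submitted inputs are the prefix sums 1, 3, 6, 10; the squares come back
    // in the OutcomeBatch and the final sum is returned directly.
    let (outcomes, total) = hive.scan(1..=4, 0i32, |acc, x| {
        *acc += x;
        *acc
    });
    assert_eq!(total, 10);
    drop(outcomes); // holds the success outputs 1, 9, 36, 100
}
```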
- pub fn scan( - &self, - items: impl IntoIterator, - init: St, - f: F, - ) -> (OutcomeBatch, St) + pub fn scan(&self, batch: B, init: S, f: F) -> (OutcomeBatch, S) where - F: FnMut(&mut St, I) -> W::Input, + B: IntoIterator, + F: FnMut(&mut S, I) -> W::Input, { let (tx, rx) = super::outcome_channel(); - let (task_ids, fold_value) = self.scan_send(items, &tx, init, f); + let (task_ids, fold_value) = self.scan_send(batch, &tx, init, f); drop(tx); let outcomes = rx.select_unordered(task_ids).into(); (outcomes, fold_value) @@ -265,17 +268,18 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing, /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final /// state value. - pub fn try_scan( + pub fn try_scan( &self, - items: impl IntoIterator, - init: St, + batch: B, + init: S, mut f: F, - ) -> (OutcomeBatch, Vec, St) + ) -> (OutcomeBatch, Vec, S) where - F: FnMut(&mut St, I) -> Result, + B: IntoIterator, + F: FnMut(&mut S, I) -> Result, { let (tx, rx) = super::outcome_channel(); - let (task_ids, errors, fold_value) = items.into_iter().fold( + let (task_ids, errors, fold_value) = batch.into_iter().fold( (Vec::new(), Vec::new(), init), |(mut task_ids, mut errors, mut acc), inp| { match f(&mut acc, inp) { @@ -294,22 +298,23 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the /// task IDs and the final state value. - pub fn scan_send( + pub fn scan_send( &self, - items: impl IntoIterator, - tx: S, - init: St, + batch: B, + outcome_tx: X, + init: S, mut f: F, - ) -> (Vec, St) + ) -> (Vec, S) where - S: Borrow>, - F: FnMut(&mut St, I) -> W::Input, + B: IntoIterator, + X: Borrow>, + F: FnMut(&mut S, I) -> W::Input, { - items + batch .into_iter() .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { let input = f(&mut acc, item); - task_ids.push(self.apply_send(input, tx.borrow())); + task_ids.push(self.apply_send(input, outcome_tx.borrow())); (task_ids, acc) }) } @@ -319,21 +324,24 @@ impl, T: TaskQueues> Hive { /// or an error. The outputs are sent to `tx` in the order they become available. This /// function returns the final state value and a [`Vec`] of results, where each result is /// either a task ID or an error. - pub fn try_scan_send( + pub fn try_scan_send( &self, - items: impl IntoIterator, - tx: S, - init: St, + batch: B, + outcome_tx: X, + init: S, mut f: F, - ) -> (Vec>, St) + ) -> (Vec>, S) where - S: Borrow>, - F: FnMut(&mut St, I) -> Result, + B: IntoIterator, + X: Borrow>, + F: FnMut(&mut S, I) -> Result, { - items + batch .into_iter() .fold((Vec::new(), init), |(mut results, mut acc), inp| { - results.push(f(&mut acc, inp).map(|input| self.apply_send(input, tx.borrow()))); + results.push( + f(&mut acc, inp).map(|input| self.apply_send(input, outcome_tx.borrow())), + ); (results, acc) }) } @@ -342,16 +350,12 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of /// the tasks are retained and available for later retrieval. 
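`try_scan_send` is the fallible variant: the closure can reject an item before it is ever submitted, so the returned `Vec` mixes `Ok(task_id)` entries with the rejection errors. A hedged usage sketch built on the same `Caller` worker:

```rust
use beekeeper::bee::stock::Caller;
use beekeeper::hive::{
    Builder, OutcomeIteratorExt, TaskQueuesBuilder, channel_builder, outcome_channel,
};

fn main() {
    let hive = channel_builder(false)
        .with_worker(Caller::of(|i: i32| i * i))
        .num_threads(4)
        .build();

    let (tx, rx) = outcome_channel();
    // Odd items are rejected up front; even items are submitted with the
    // running even-sum as their input.
    let (results, _sum) = hive.try_scan_send(0..8, &tx, 0i32, |acc, x| {
        if x % 2 == 0 {
            *acc += x;
            Ok(*acc)
        } else {
            Err(format!("rejected odd input {x}"))
        }
    });
    assert_eq!(results.iter().filter(|r| r.is_ok()).count(), 4);
    drop(tx); // the receiver disconnects once the four outcomes arrive
    assert_eq!(rx.into_outputs().count(), 4);
}
```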
- pub fn scan_store( - &self, - items: impl IntoIterator, - init: St, - mut f: F, - ) -> (Vec, St) + pub fn scan_store(&self, batch: B, init: S, mut f: F) -> (Vec, S) where - F: FnMut(&mut St, I) -> W::Input, + B: IntoIterator, + F: FnMut(&mut S, I) -> W::Input, { - items + batch .into_iter() .fold((Vec::new(), init), |(mut task_ids, mut acc), item| { let input = f(&mut acc, item); @@ -365,16 +369,17 @@ impl, T: TaskQueues> Hive { /// or an error. This function returns the final value of the state value and a [`Vec`] of /// results, where each result is either a task ID or an error. The [`Outcome`]s of the /// tasks are retained and available for later retrieval. - pub fn try_scan_store( + pub fn try_scan_store( &self, - items: impl IntoIterator, - init: St, + batch: B, + init: S, mut f: F, - ) -> (Vec>, St) + ) -> (Vec>, S) where - F: FnMut(&mut St, I) -> Result, + B: IntoIterator, + F: FnMut(&mut S, I) -> Result, { - items + batch .into_iter() .fold((Vec::new(), init), |(mut results, mut acc), item| { results.push(f(&mut acc, item).map(|input| self.apply_store(input))); @@ -1005,7 +1010,8 @@ mod tests { use super::Poisoned; use crate::bee::stock::{Caller, Thunk, ThunkWorker}; use crate::hive::{ - outcome_channel, Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder, + Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder, channel_builder, + outcome_channel, }; use std::collections::HashMap; use std::thread; @@ -1013,12 +1019,15 @@ mod tests { #[test] fn test_suspend() { - let hive = ChannelBuilder::empty() + let hive = channel_builder(false) .num_threads(4) .with_worker_default::>() .build(); - let outcome_iter = - hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); + let (tx, rx) = outcome_channel(); + hive.map_send( + (0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3)))), + tx, + ); // Allow first set of tasks to be started. thread::sleep(Duration::from_secs(1)); // There should be 4 active tasks and 6 queued tasks. @@ -1031,7 +1040,7 @@ mod tests { // Wait for remaining tasks to complete. 
hive.join(); assert_eq!(hive.num_tasks(), (0, 0)); - let outputs: Vec<_> = outcome_iter.into_outputs().collect(); + let outputs: Vec<_> = rx.into_outputs().collect(); assert_eq!(outputs.len(), 10); } diff --git a/src/hive/mod.rs b/src/hive/mod.rs index cd18120..34a0c60 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -461,9 +461,10 @@ mod util { #[cfg(test)] mod tests { + use super::inner::TaskQueues; use super::{ - Builder, ChannelBuilder, ChannelTaskQueues, Hive, OpenBuilder, Outcome, OutcomeIteratorExt, - OutcomeStore, TaskQueuesBuilder, + channel_builder, workstealing_builder, Builder, ChannelTaskQueues, Hive, Outcome, + OutcomeIteratorExt, OutcomeStore, TaskQueuesBuilder, WorkstealingTaskQueues, }; use crate::barrier::IndexedBarrier; use crate::bee::stock::{Caller, OnceCaller, RefCaller, Thunk, ThunkWorker}; @@ -473,6 +474,7 @@ mod tests { }; use crate::channel::{Message, ReceiverExt}; use crate::hive::outcome::DerefOutcomes; + use rstest::*; use std::fmt::Debug; use std::io::{self, BufRead, BufReader, Write}; use std::process::{Child, ChildStdin, ChildStdout, Command, ExitStatus, Stdio}; @@ -488,33 +490,37 @@ mod tests { const SHORT_TASK: Duration = Duration::from_secs(2); const LONG_TASK: Duration = Duration::from_secs(5); - type TWrk = ThunkWorker; - type THive = Hive>, ChannelTaskQueues>>; + type TWrk = ThunkWorker; /// Convenience function that returns a `Hive` configured with the global defaults, and the /// specified number of workers that execute `Thunk`s, i.e. closures that return `T`. - pub fn thunk_hive( - num_threads: usize, - with_defaults: bool, - ) -> THive { - let builder = if with_defaults { - ChannelBuilder::default() - } else { - ChannelBuilder::empty() - }; + pub fn thunk_hive(num_threads: usize, builder: B) -> Hive>, T> + where + I: Send + Sync + Debug + 'static, + T: TaskQueues>, + B: TaskQueuesBuilder> = T>, + { builder .num_threads(num_threads) .with_queen_default() .build() } - pub fn void_thunk_hive(num_threads: usize, with_defaults: bool) -> THive<()> { - thunk_hive(num_threads, with_defaults) + pub fn void_thunk_hive(num_threads: usize, builder: B) -> Hive>, T> + where + T: TaskQueues>, + B: TaskQueuesBuilder> = T>, + { + thunk_hive(num_threads, builder) } - #[test] - fn test_works() { - let hive = thunk_hive(TEST_TASKS, true); + #[rstest] + fn test_works(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive(TEST_TASKS, builder_factory(true)); let (tx, rx) = mpsc::channel(); assert_eq!(hive.max_workers(), TEST_TASKS); assert_eq!(hive.alive_workers(), TEST_TASKS); @@ -528,9 +534,14 @@ mod tests { assert_eq!(rx.iter().take(TEST_TASKS).sum::(), TEST_TASKS); } - #[test] - fn test_grow_from_zero() { - let hive = thunk_hive::(0, true); + #[rstest] + fn test_grow_from_zero( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(0, builder_factory(true)); // check that with 0 threads no tasks are scheduled let (tx, rx) = super::outcome_channel(); let _ = hive.apply_send(Thunk::of(|| 0), &tx); @@ -546,9 +557,14 @@ mod tests { )); } - #[test] - fn test_grow_from_nonzero() { - let hive = void_thunk_hive(TEST_TASKS, false); + #[rstest] + fn test_grow_from_nonzero( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = void_thunk_hive(TEST_TASKS, builder_factory(false)); // queue some 
long-running tasks for _ in 0..TEST_TASKS { hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); @@ -569,9 +585,13 @@ mod tests { assert_eq!(husk.iter_successes().count(), total_threads); } - #[test] - fn test_suspend() { - let hive = void_thunk_hive(TEST_TASKS, false); + #[rstest] + fn test_suspend(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = void_thunk_hive(TEST_TASKS, builder_factory(false)); // queue some long-running tasks let total_tasks = 2 * TEST_TASKS; for _ in 0..total_tasks { @@ -617,9 +637,14 @@ mod tests { } } - #[test] - fn test_suspend_with_cancelled_tasks() { - let hive: Hive<_, _> = ChannelBuilder::empty() + #[rstest] + fn test_suspend_with_cancelled_tasks( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive: Hive<_, _> = builder_factory(false) .num_threads(TEST_TASKS) .with_worker_default::() .build(); @@ -635,9 +660,14 @@ mod tests { assert_eq!(hive.num_successes(), TEST_TASKS); } - #[test] - fn test_num_tasks_active() { - let hive = void_thunk_hive(TEST_TASKS, false); + #[rstest] + fn test_num_tasks_active( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = void_thunk_hive(TEST_TASKS, builder_factory(false)); for _ in 0..2 * TEST_TASKS { hive.apply_store(Thunk::of(|| loop { thread::sleep(LONG_TASK) @@ -649,9 +679,13 @@ mod tests { assert_eq!(num_threads, TEST_TASKS); } - #[test] - fn test_all_threads() { - let hive: THive<()> = ChannelBuilder::empty() + #[rstest] + fn test_all_threads(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive: Hive>, _> = builder_factory(false) .with_queen_default() .with_thread_per_core() .build(); @@ -667,9 +701,13 @@ mod tests { assert_eq!(num_threads, num_threads); } - #[test] - fn test_panic() { - let hive = thunk_hive(TEST_TASKS, true); + #[rstest] + fn test_panic(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive(TEST_TASKS, builder_factory(true)); let (tx, _) = super::outcome_channel(); // Panic all the existing threads. 
for _ in 0..TEST_TASKS { @@ -682,9 +720,13 @@ mod tests { assert_eq!(husk.num_panics(), TEST_TASKS); } - #[test] - fn test_catch_panic() { - let hive: Hive<_, _> = ChannelBuilder::empty() + #[rstest] + fn test_catch_panic(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive: Hive<_, _> = builder_factory(false) .with_worker(RefCaller::of(|_: &u8| -> Result { panic!("intentional panic") })) @@ -704,9 +746,14 @@ mod tests { } } - #[test] - fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() { - let hive = void_thunk_hive(TEST_TASKS, false); + #[rstest] + fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = void_thunk_hive(TEST_TASKS, builder_factory(false)); let waiter = Arc::new(Barrier::new(TEST_TASKS + 1)); let waiter_count = Arc::new(AtomicUsize::new(0)); @@ -733,11 +780,16 @@ mod tests { waiter.wait(); } - #[test] - fn test_massive_task_creation() { + #[rstest] + fn test_massive_task_creation( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { let test_tasks = 4_200_000; - let hive = thunk_hive(TEST_TASKS, true); + let hive = thunk_hive(TEST_TASKS, builder_factory(true)); let b0 = IndexedBarrier::new(TEST_TASKS); let b1 = IndexedBarrier::new(TEST_TASKS); @@ -771,10 +823,14 @@ mod tests { ); } - #[test] - fn test_name() { + #[rstest] + fn test_name(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { let name = "test"; - let hive: THive<()> = ChannelBuilder::empty() + let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() .thread_name(name.to_owned()) .num_threads(2) @@ -803,11 +859,15 @@ mod tests { } } - #[test] - fn test_stack_size() { + #[rstest] + fn test_stack_size(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { let stack_size = 4_000_000; - let hive: THive = ChannelBuilder::empty() + let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() .num_threads(1) .thread_stack_size(stack_size) @@ -825,16 +885,20 @@ mod tests { assert!(actual_stack_size < (stack_size as f64 * 1.01)); } - #[test] - fn test_debug() { - let hive = void_thunk_hive(4, true); + #[rstest] + fn test_debug(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = void_thunk_hive(4, builder_factory(true)); let debug = format!("{:?}", hive); assert_eq!( debug, "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" ); - let hive: THive = ChannelBuilder::empty() + let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() .thread_name("hello") .num_threads(4) @@ -845,7 +909,7 @@ mod tests { "Hive { shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" ); - let hive = thunk_hive(4, true); + let hive = thunk_hive(4, builder_factory(true)); hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); thread::sleep(ONE_SEC); let debug = format!("{:?}", hive); @@ -855,9 +919,13 @@ mod tests { ); } - #[test] - fn test_repeated_join() { - let hive: THive<()> = ChannelBuilder::empty() + #[rstest] + fn 
test_repeated_join(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() .thread_name("repeated join test") .num_threads(8) @@ -887,8 +955,12 @@ mod tests { assert_eq!(84, test_count.load(Ordering::Relaxed)); } - #[test] - fn test_multi_join() { + #[rstest] + fn test_multi_join(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { // Toggle the following lines to debug the deadlock // fn error(_s: String) { // use ::std::io::Write; @@ -899,12 +971,12 @@ mod tests { // .expect("Failed to write to stderr"); // } - let hive0: THive<()> = ChannelBuilder::empty() + let hive0: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() .thread_name("multi join pool0") .num_threads(4) .build(); - let hive1: THive<()> = ChannelBuilder::empty() + let hive1: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() .thread_name("multi join pool1") .num_threads(4) @@ -940,23 +1012,31 @@ mod tests { assert_eq!(rx.into_iter().sum::(), (0..8).sum()); } - #[test] - fn test_empty_hive() { + #[rstest] + fn test_empty_hive(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { // Joining an empty hive must return imminently // TODO: run this in a thread and kill it after a timeout to prevent hanging the tests - let hive = void_thunk_hive(4, true); + let hive = void_thunk_hive(4, builder_factory(true)); hive.join(); } - #[test] - fn test_no_fun_or_joy() { + #[rstest] + fn test_no_fun_or_joy(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { // What happens when you keep adding tasks after a join fn sleepy_function() { thread::sleep(LONG_TASK); } - let hive: THive<()> = ChannelBuilder::empty() + let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() .thread_name("no fun or joy") .num_threads(8) @@ -976,9 +1056,13 @@ mod tests { hive.join(); } - #[test] - fn test_map() { - let hive = thunk_hive::(2, false); + #[rstest] + fn test_map(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(2, builder_factory(false)); let outputs: Vec<_> = hive .map((0..10u8).map(|i| { Thunk::of(move || { @@ -991,9 +1075,13 @@ mod tests { assert_eq!(outputs, (0..10).collect::>()) } - #[test] - fn test_map_unordered() { - let hive = thunk_hive::(8, false); + #[rstest] + fn test_map_unordered(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(8, builder_factory(false)); let outputs: Vec<_> = hive .map_unordered((0..8u8).map(|i| { Thunk::of(move || { @@ -1006,9 +1094,13 @@ mod tests { assert_eq!(outputs, (0..8).rev().collect::>()) } - #[test] - fn test_map_send() { - let hive = thunk_hive::(8, false); + #[rstest] + fn test_map_send(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(8, builder_factory(false)); let (tx, rx) = super::outcome_channel(); let mut task_ids = hive.map_send( (0..8u8).map(|i| { @@ -1032,9 +1124,13 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] - fn test_map_store() { - let mut hive = 
thunk_hive::(8, false); + #[rstest] + fn test_map_store(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let mut hive = thunk_hive::(8, builder_factory(false)); let mut task_ids = hive.map_store((0..8u8).map(|i| { Thunk::of(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); @@ -1056,9 +1152,13 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] - fn test_swarm() { - let hive = thunk_hive::(2, false); + #[rstest] + fn test_swarm(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(2, builder_factory(false)); let outputs: Vec<_> = hive .swarm((0..10u8).map(|i| { Thunk::of(move || { @@ -1071,9 +1171,14 @@ mod tests { assert_eq!(outputs, (0..10).collect::>()) } - #[test] - fn test_swarm_unordered() { - let hive = thunk_hive::(8, false); + #[rstest] + fn test_swarm_unordered( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(8, builder_factory(false)); let outputs: Vec<_> = hive .swarm_unordered((0..8u8).map(|i| { Thunk::of(move || { @@ -1086,9 +1191,13 @@ mod tests { assert_eq!(outputs, (0..8).rev().collect::>()) } - #[test] - fn test_swarm_send() { - let hive = thunk_hive::(8, false); + #[rstest] + fn test_swarm_send(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(8, builder_factory(false)); #[cfg(feature = "batching")] assert_eq!(hive.worker_batch_limit(), 0); let (tx, rx) = super::outcome_channel(); @@ -1114,9 +1223,13 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] - fn test_swarm_store() { - let mut hive = thunk_hive::(8, false); + #[rstest] + fn test_swarm_store(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let mut hive = thunk_hive::(8, builder_factory(false)); let mut task_ids = hive.swarm_store((0..8u8).map(|i| { Thunk::of(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); @@ -1138,9 +1251,13 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] - fn test_scan() { - let hive = ChannelBuilder::empty() + #[rstest] + fn test_scan(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .with_worker(Caller::of(|i| i * i)) .num_threads(4) .build(); @@ -1164,9 +1281,13 @@ mod tests { ); } - #[test] - fn test_scan_send() { - let hive = ChannelBuilder::empty() + #[rstest] + fn test_scan_send(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .with_worker(Caller::of(|i| i * i)) .num_threads(4) .build(); @@ -1200,9 +1321,13 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] - fn test_try_scan_send() { - let hive = ChannelBuilder::empty() + #[rstest] + fn test_try_scan_send(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .with_worker(Caller::of(|i| i * i)) .num_threads(4) .build(); @@ -1237,10 +1362,15 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] + #[rstest] #[should_panic] - fn 
test_try_scan_send_fail() { - let hive = ChannelBuilder::empty() + fn test_try_scan_send_fail( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .with_worker(OnceCaller::of(|i: i32| Ok::<_, String>(i * i))) .num_threads(4) .build(); @@ -1253,9 +1383,13 @@ mod tests { .collect::>(); } - #[test] - fn test_scan_store() { - let mut hive = ChannelBuilder::empty() + #[rstest] + fn test_scan_store(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let mut hive = builder_factory(false) .with_worker(Caller::of(|i| i * i)) .num_threads(4) .build(); @@ -1289,9 +1423,14 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] - fn test_try_scan_store() { - let mut hive = ChannelBuilder::empty() + #[rstest] + fn test_try_scan_store( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let mut hive = builder_factory(false) .with_worker(Caller::of(|i| i * i)) .num_threads(4) .build(); @@ -1326,10 +1465,15 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } - #[test] + #[rstest] #[should_panic] - fn test_try_scan_store_fail() { - let hive = ChannelBuilder::empty() + fn test_try_scan_store_fail( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .with_worker(OnceCaller::of(|i: i32| Ok::(i * i))) .num_threads(4) .build(); @@ -1341,9 +1485,13 @@ mod tests { .collect::>(); } - #[test] - fn test_husk() { - let hive1 = thunk_hive::(8, false); + #[rstest] + fn test_husk(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive1 = thunk_hive::(8, builder_factory(false)); let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::of(move || i))); hive1.join(); let mut husk1 = hive1.try_into_husk(false).unwrap(); @@ -1399,9 +1547,13 @@ mod tests { assert_eq!(outputs1, outputs3); } - #[test] - fn test_clone() { - let hive: THive<()> = ChannelBuilder::empty() + #[rstest] + fn test_clone(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_worker_default() .thread_name("clone example") .num_threads(2) @@ -1462,19 +1614,29 @@ mod tests { ); } - #[test] - fn test_send() { + #[rstest] + fn test_channel_hive_send() { + fn assert_send() {} + assert_send::>, ChannelTaskQueues<_>>>(); + } + + #[rstest] + fn test_workstealing_hive_send() { fn assert_send() {} - assert_send::>(); + assert_send::>, WorkstealingTaskQueues<_>>>(); } - #[test] - fn test_cloned_eq() { - let a = thunk_hive::<()>(2, true); + #[rstest] + fn test_cloned_eq(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let a = thunk_hive::<(), _, _>(2, builder_factory(true)); assert_eq!(a, a.clone()); } - #[test] + #[rstest] /// When a thread joins on a pool, it blocks until all tasks have completed. If a second thread /// adds tasks to the pool and then joins before all the tasks have completed, both threads /// will wait for all tasks to complete. 
However, as soon as all tasks have completed, all @@ -1484,14 +1646,18 @@ mod tests { /// /// In this example, this means the waiting threads will exit the join in groups of four /// because the waiter pool has four processes. - fn test_join_wavesurfer() { + fn test_join_wavesurfer( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { let n_waves = 4; let n_workers = 4; let (tx, rx) = mpsc::channel(); - let builder = OpenBuilder::empty() + let builder = builder_factory(false) .num_threads(n_workers) - .thread_name("join wavesurfer") - .with_channel_queues(); + .thread_name("join wavesurfer"); let waiter_hive = builder .clone() .with_worker_default::>() @@ -1574,10 +1740,14 @@ mod tests { // cargo-llvm-cov doesn't yet support doctests in stable, so we need to duplicate them in // unit tests to get coverage - #[test] - fn doctest_lib_2() { + #[rstest] + fn doctest_lib_2(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { // create a hive to process `Thunk`s - no-argument closures with the same return type (`i32`) - let hive: THive = ChannelBuilder::empty() + let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_worker_default() .num_threads(4) .thread_name("thunk_hive") @@ -1597,8 +1767,12 @@ mod tests { assert_eq!(-285, outputs2.into_iter().sum()); } - #[test] - fn doctest_lib_3() { + #[rstest] + fn doctest_lib_3(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { #[derive(Debug)] struct CatWorker { stdin: ChildStdin, @@ -1681,7 +1855,7 @@ mod tests { } // build the Hive - let hive = ChannelBuilder::empty() + let hive = builder_factory(false) .with_queen_mut_default::() .num_threads(4) .build(); @@ -1736,11 +1910,16 @@ mod tests { #[cfg(all(test, feature = "affinity"))] mod affinity_tests { use crate::bee::stock::{Thunk, ThunkWorker}; - use crate::hive::{Builder, TaskQueuesBuilder}; - - #[test] - fn test_affinity() { - let hive = crate::hive::channel_builder(false) + use crate::hive::{channel_builder, workstealing_builder, Builder, TaskQueuesBuilder}; + use rstest::*; + + #[rstest] + fn test_affinity(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .thread_name("affinity example") .num_threads(2) .core_affinity(0..2) @@ -1756,7 +1935,7 @@ mod affinity_tests { })); } - #[test] + #[rstest] fn test_use_all_cores() { let hive = crate::hive::channel_builder(false) .thread_name("affinity example") @@ -1781,15 +1960,16 @@ mod batching_tests { use crate::bee::stock::{Thunk, ThunkWorker}; use crate::bee::DefaultQueen; use crate::hive::{ - Builder, ChannelBuilder, ChannelTaskQueues, Hive, OutcomeIteratorExt, OutcomeReceiver, - OutcomeSender, TaskQueuesBuilder, + channel_builder, workstealing_builder, Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, + OutcomeSender, TaskQueues, TaskQueuesBuilder, }; + use rstest::*; use std::collections::HashMap; use std::thread::{self, ThreadId}; use std::time::Duration; - fn launch_tasks( - hive: &Hive>, ChannelTaskQueues>>, + fn launch_tasks>>( + hive: &Hive>, T>, num_threads: usize, num_tasks_per_thread: usize, barrier: &IndexedBarrier, @@ -1836,8 +2016,8 @@ mod batching_tests { }) } - fn run_test( - hive: &Hive>, ChannelTaskQueues>>, + fn run_test>>( + hive: &Hive>, T>, num_threads: usize, batch_limit: usize, ) { @@ 
-1858,11 +2038,15 @@ mod batching_tests { .all(|&count| count == tasks_per_thread)); } - #[test] - fn test_batching() { + #[rstest] + fn test_batching(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { const NUM_THREADS: usize = 4; const BATCH_LIMIT: usize = 24; - let hive = ChannelBuilder::empty() + let hive = builder_factory(false) .with_worker_default() .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT) @@ -1870,13 +2054,18 @@ mod batching_tests { run_test(&hive, NUM_THREADS, BATCH_LIMIT); } - #[test] - fn test_set_batch_limit() { + #[rstest] + fn test_set_batch_limit( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { const NUM_THREADS: usize = 4; const BATCH_LIMIT_0: usize = 10; const BATCH_LIMIT_1: usize = 20; const BATCH_LIMIT_2: usize = 50; - let hive = ChannelBuilder::empty() + let hive = builder_factory(false) .with_worker_default() .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT_0) @@ -1890,13 +2079,18 @@ mod batching_tests { run_test(&hive, NUM_THREADS, BATCH_LIMIT_1); } - #[test] - fn test_shrink_batch_limit() { + #[rstest] + fn test_shrink_batch_limit( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { const NUM_THREADS: usize = 4; const NUM_TASKS_PER_THREAD: usize = 125; const BATCH_LIMIT_0: usize = 100; const BATCH_LIMIT_1: usize = 10; - let hive = ChannelBuilder::empty() + let hive = builder_factory(false) .with_worker_default() .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT_0) @@ -1922,7 +2116,11 @@ mod batching_tests { mod retry_tests { use crate::bee::stock::RetryCaller; use crate::bee::{ApplyError, Context}; - use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder}; + use crate::hive::{ + channel_builder, workstealing_builder, Builder, Outcome, OutcomeIteratorExt, + TaskQueuesBuilder, + }; + use rstest::*; use std::time::{Duration, SystemTime}; fn echo_time(i: usize, ctx: &Context) -> Result> { @@ -1939,9 +2137,13 @@ mod retry_tests { } } - #[test] - fn test_retries() { - let hive = ChannelBuilder::empty() + #[rstest] + fn test_retries(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .with_worker(RetryCaller::of(echo_time)) .with_thread_per_core() .max_retries(3) @@ -1952,8 +2154,12 @@ mod retry_tests { assert_eq!(v.unwrap().len(), 10); } - #[test] - fn test_retries_fail() { + #[rstest] + fn test_retries_fail(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { fn sometimes_fail( i: usize, _: &Context, @@ -1972,7 +2178,7 @@ mod retry_tests { } } - let hive = ChannelBuilder::empty() + let hive = builder_factory(false) .with_worker(RetryCaller::of(sometimes_fail)) .with_thread_per_core() .max_retries(3) @@ -1993,9 +2199,14 @@ mod retry_tests { assert_eq!(not_retried, 3); } - #[test] - fn test_disable_retries() { - let hive = ChannelBuilder::empty() + #[rstest] + fn test_disable_retries( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .with_worker(RetryCaller::of(echo_time)) .with_thread_per_core() .with_no_retries() From 59ea7f05a66ba758832271928f63e8177fd66990 Mon Sep 17 00:00:00 2001 From: jdidion Date: 
Sat, 22 Feb 2025 14:03:04 -0800 Subject: [PATCH 30/67] fix test --- src/hive/mod.rs | 71 +++++++++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 34a0c60..68f10a0 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -389,19 +389,19 @@ mod husk; mod inner; mod outcome; +pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder, TaskQueuesBuilder}; pub use self::builder::{ channel as channel_builder, open as open_builder, workstealing as workstealing_builder, }; -pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder, TaskQueuesBuilder}; pub use self::hive::{DefaultHive, Hive, Poisoned}; pub use self::husk::Husk; -pub use self::inner::{set_config::*, Builder, ChannelTaskQueues, WorkstealingTaskQueues}; +pub use self::inner::{Builder, ChannelTaskQueues, WorkstealingTaskQueues, set_config::*}; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; use self::inner::{Config, Shared, Task, TaskQueues, WorkerQueues}; use self::outcome::{DerefOutcomes, OutcomeQueue, OwnedOutcomes}; use crate::bee::Worker; -use crate::channel::{channel, Receiver, Sender}; +use crate::channel::{Receiver, Sender, channel}; use std::io::Error as SpawnError; /// Sender type for channel used to send task outcomes. @@ -418,8 +418,8 @@ pub fn outcome_channel() -> (OutcomeSender, OutcomeReceiver) { pub mod prelude { pub use super::{ - channel_builder, open_builder, outcome_channel, workstealing_builder, Builder, Hive, Husk, - Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned, TaskQueuesBuilder, + Builder, Hive, Husk, Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned, + TaskQueuesBuilder, channel_builder, open_builder, outcome_channel, workstealing_builder, }; } @@ -463,8 +463,8 @@ mod util { mod tests { use super::inner::TaskQueues; use super::{ - channel_builder, workstealing_builder, Builder, ChannelTaskQueues, Hive, Outcome, - OutcomeIteratorExt, OutcomeStore, TaskQueuesBuilder, WorkstealingTaskQueues, + Builder, ChannelTaskQueues, Hive, Outcome, OutcomeIteratorExt, OutcomeStore, + TaskQueuesBuilder, WorkstealingTaskQueues, channel_builder, workstealing_builder, }; use crate::barrier::IndexedBarrier; use crate::bee::stock::{Caller, OnceCaller, RefCaller, Thunk, ThunkWorker}; @@ -479,8 +479,9 @@ mod tests { use std::io::{self, BufRead, BufReader, Write}; use std::process::{Child, ChildStdin, ChildStdout, Command, ExitStatus, Stdio}; use std::sync::{ + Arc, Barrier, atomic::{AtomicUsize, Ordering}, - mpsc, Arc, Barrier, + mpsc, }; use std::thread; use std::time::Duration; @@ -669,8 +670,10 @@ mod tests { { let hive = void_thunk_hive(TEST_TASKS, builder_factory(false)); for _ in 0..2 * TEST_TASKS { - hive.apply_store(Thunk::of(|| loop { - thread::sleep(LONG_TASK) + hive.apply_store(Thunk::of(|| { + loop { + thread::sleep(LONG_TASK) + } })); } thread::sleep(ONE_SEC); @@ -691,14 +694,16 @@ mod tests { .build(); let num_threads = num_cpus::get(); for _ in 0..num_threads { - hive.apply_store(Thunk::of(|| loop { - thread::sleep(LONG_TASK) + hive.apply_store(Thunk::of(|| { + loop { + thread::sleep(LONG_TASK) + } })); } thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().1, num_threads as u64); - let num_threads = hive.max_workers(); - assert_eq!(num_threads, num_threads); + let max_workers = hive.max_workers(); + assert_eq!(num_threads, max_workers); } #[rstest] @@ -894,9 +899,9 @@ mod tests { let hive = void_thunk_hive(4, 
builder_factory(true)); let debug = format!("{:?}", hive); assert_eq!( - debug, - "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" - ); + debug, + "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + ); let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) .with_queen_default() @@ -905,18 +910,18 @@ mod tests { .build(); let debug = format!("{:?}", hive); assert_eq!( - debug, - "Hive { shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" - ); + debug, + "Hive { shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + ); let hive = thunk_hive(4, builder_factory(true)); hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); thread::sleep(ONE_SEC); let debug = format!("{:?}", hive); assert_eq!( - debug, - "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" - ); + debug, + "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" + ); } #[rstest] @@ -1910,7 +1915,7 @@ mod tests { #[cfg(all(test, feature = "affinity"))] mod affinity_tests { use crate::bee::stock::{Thunk, ThunkWorker}; - use crate::hive::{channel_builder, workstealing_builder, Builder, TaskQueuesBuilder}; + use crate::hive::{Builder, TaskQueuesBuilder, channel_builder, workstealing_builder}; use rstest::*; #[rstest] @@ -1957,11 +1962,11 @@ mod affinity_tests { #[cfg(all(test, feature = "batching"))] mod batching_tests { use crate::barrier::IndexedBarrier; - use crate::bee::stock::{Thunk, ThunkWorker}; use crate::bee::DefaultQueen; + use crate::bee::stock::{Thunk, ThunkWorker}; use crate::hive::{ - channel_builder, workstealing_builder, Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, - OutcomeSender, TaskQueues, TaskQueuesBuilder, + Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender, TaskQueues, + TaskQueuesBuilder, channel_builder, workstealing_builder, }; use rstest::*; use std::collections::HashMap; @@ -2033,9 +2038,11 @@ mod batching_tests { hive.join(); let thread_counts = count_thread_ids(rx, task_ids); assert_eq!(thread_counts.len(), num_threads); - assert!(thread_counts - .values() - .all(|&count| count == tasks_per_thread)); + assert!( + thread_counts + .values() + .all(|&count| count == tasks_per_thread) + ); } #[rstest] @@ -2117,8 +2124,8 @@ mod retry_tests { use crate::bee::stock::RetryCaller; use crate::bee::{ApplyError, Context}; use crate::hive::{ - channel_builder, workstealing_builder, Builder, Outcome, OutcomeIteratorExt, - TaskQueuesBuilder, + Builder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder, channel_builder, + workstealing_builder, }; use rstest::*; use std::time::{Duration, SystemTime}; From 8f79aa3e409b78aebcc7ef9e4e2bcfcc71981672 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 15:14:45 -0800 Subject: [PATCH 31/67] add workstealng tests --- Cargo.toml | 2 +- src/hive/inner/queue/channel.rs | 2 +- src/hive/inner/queue/workstealing.rs | 24 +++++++-- src/hive/inner/shared.rs | 2 +- src/hive/mod.rs | 80 +++++++++++++++++++--------- 5 files changed, 78 insertions(+), 32 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index db4c6a5..8fba746 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,7 @@ name = "perf" harness = false [features] -default = ["affinity", "batching", "retry"] +default = [] affinity = ["dep:core_affinity"] batching = [] retry = [] diff --git a/src/hive/inner/queue/channel.rs 
b/src/hive/inner/queue/channel.rs index bd354f5..d8eef27 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use std::time::Duration; // time to wait when polling the global queue -const RECV_TIMEOUT: Duration = Duration::from_secs(1); +const RECV_TIMEOUT: Duration = Duration::from_millis(100); /// Type alias for the input task channel sender type TaskSender = crossbeam_channel::Sender>; diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index a05c77c..8810169 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -14,6 +14,11 @@ use parking_lot::RwLock; use rand::prelude::*; use std::ops::Deref; use std::sync::Arc; +use std::thread; +use std::time::Duration; + +/// Time to wait after trying to pop and finding all queues empty. +const EMPTY_DELAY: Duration = Duration::from_millis(100); pub struct WorkstealingTaskQueues { global: Arc>, @@ -40,7 +45,7 @@ impl TaskQueues for WorkstealingTaskQueues { fn update_for_threads(&self, start_index: usize, end_index: usize, config: &Config) { let local_queues = self.local.read(); - assert!(local_queues.len() > end_index); + assert!(local_queues.len() >= end_index); (start_index..end_index).for_each(|thread_index| local_queues[thread_index].update(config)); } @@ -103,7 +108,7 @@ impl GlobalQueue { } /// Tries to steal a task from a random worker using its `Stealer`. - fn try_steal(&self) -> Option> { + fn try_steal_from_worker(&self) -> Result, PopTaskError> { let stealers = self.stealers.read(); let n = stealers.len(); // randomize the stealing order, to prevent always stealing from the same thread @@ -111,6 +116,14 @@ impl GlobalQueue { .take(n) .filter_map(|i| stealers[i].steal().success()) .next() + .ok_or_else(|| { + if self.is_closed() && self.queue.is_empty() { + PopTaskError::Closed + } else { + thread::park_timeout(EMPTY_DELAY); + PopTaskError::Empty + } + }) } /// Tries to steal a task from the global queue, otherwise tries to steal a task from another @@ -119,7 +132,7 @@ impl GlobalQueue { if let Some(task) = self.queue.steal().success() { Ok(task) } else { - self.try_steal().ok_or(PopTaskError::Empty) + self.try_steal_from_worker() } } @@ -137,9 +150,10 @@ impl GlobalQueue { .steal_batch_with_limit_and_pop(local_batch, limit + 1) .success() { - return Ok(task); + Ok(task) + } else { + self.try_steal_from_worker() } - self.try_steal().ok_or(PopTaskError::Empty) } fn is_closed(&self) -> bool { diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index 41be1c9..1beeca9 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -717,8 +717,8 @@ mod retry { #[cfg(test)] mod tests { - use crate::bee::stock::ThunkWorker; use crate::bee::DefaultQueen; + use crate::bee::stock::ThunkWorker; use crate::hive::ChannelTaskQueues; type VoidThunkWorker = ThunkWorker<()>; diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 68f10a0..2c2942d 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -2025,6 +2025,7 @@ mod batching_tests { hive: &Hive>, T>, num_threads: usize, batch_limit: usize, + assert_exact: bool, ) { let tasks_per_thread = batch_limit + 2; let (tx, rx) = crate::hive::outcome_channel(); @@ -2038,52 +2039,83 @@ mod batching_tests { hive.join(); let thread_counts = count_thread_ids(rx, task_ids); assert_eq!(thread_counts.len(), num_threads); - assert!( - thread_counts - .values() - .all(|&count| count == tasks_per_thread) + assert_eq!( + thread_counts.values().sum::(), + 
tasks_per_thread * num_threads ); + if assert_exact { + assert!( + thread_counts + .values() + .all(|&count| count == tasks_per_thread) + ); + } else { + assert!(thread_counts.values().all(|&count| count > 0)); + } } #[rstest] - fn test_batching(#[values(channel_builder, workstealing_builder)] builder_factory: F) - where - B: TaskQueuesBuilder, - F: Fn(bool) -> B, - { + fn test_batching_channel() { const NUM_THREADS: usize = 4; const BATCH_LIMIT: usize = 24; - let hive = builder_factory(false) + let hive = channel_builder(false) .with_worker_default() .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT) .build(); - run_test(&hive, NUM_THREADS, BATCH_LIMIT); + run_test(&hive, NUM_THREADS, BATCH_LIMIT, true); } #[rstest] - fn test_set_batch_limit( - #[values(channel_builder, workstealing_builder)] builder_factory: F, - ) where - B: TaskQueuesBuilder, - F: Fn(bool) -> B, - { + fn test_batching_workstealing() { + const NUM_THREADS: usize = 4; + const BATCH_LIMIT: usize = 24; + let hive = workstealing_builder(false) + .with_worker_default() + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT) + .build(); + run_test(&hive, NUM_THREADS, BATCH_LIMIT, false); + } + + #[rstest] + fn test_set_batch_limit_channel() { const NUM_THREADS: usize = 4; const BATCH_LIMIT_0: usize = 10; - const BATCH_LIMIT_1: usize = 20; - const BATCH_LIMIT_2: usize = 50; - let hive = builder_factory(false) + const BATCH_LIMIT_1: usize = 50; + const BATCH_LIMIT_2: usize = 20; + let hive = channel_builder(false) .with_worker_default() .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT_0) .build(); - run_test(&hive, NUM_THREADS, BATCH_LIMIT_0); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_0, true); // increase batch size - hive.set_worker_batch_limit(BATCH_LIMIT_2); - run_test(&hive, NUM_THREADS, BATCH_LIMIT_2); + hive.set_worker_batch_limit(BATCH_LIMIT_1); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_1, true); // decrease batch size + hive.set_worker_batch_limit(BATCH_LIMIT_2); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_2, true); + } + + #[rstest] + fn test_set_batch_limit_workstealing() { + const NUM_THREADS: usize = 4; + const BATCH_LIMIT_0: usize = 10; + const BATCH_LIMIT_1: usize = 50; + const BATCH_LIMIT_2: usize = 20; + let hive = workstealing_builder(false) + .with_worker_default() + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT_0) + .build(); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_0, false); + // increase batch size hive.set_worker_batch_limit(BATCH_LIMIT_1); - run_test(&hive, NUM_THREADS, BATCH_LIMIT_1); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_1, false); + // decrease batch size + hive.set_worker_batch_limit(BATCH_LIMIT_2); + run_test(&hive, NUM_THREADS, BATCH_LIMIT_2, false); } #[rstest] From 5c9e5c8263f7a95f2ad13cbdbbe0c5ec6d294c05 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 15:29:58 -0800 Subject: [PATCH 32/67] fix workflow --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 168c05f..322bdc3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,6 +14,8 @@ jobs: steps: - uses: actions/checkout@v4 - uses: EmbarkStudios/cargo-deny-action@v2 + with: + rust-version: "1.85.0" - uses: actions-rust-lang/setup-rust-toolchain@v1 with: components: rustfmt, clippy @@ -29,7 +31,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: EmbarkStudios/cargo-deny-action@v2 - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Install cargo-llvm-cov 
uses: taiki-e/install-action@cargo-llvm-cov @@ -58,7 +59,6 @@ jobs: - loole steps: - uses: actions/checkout@v4 - - uses: EmbarkStudios/cargo-deny-action@v2 - uses: actions-rust-lang/setup-rust-toolchain@v1 - uses: actions-rs/cargo@v1 with: From 44349f35ebb30ce2ecb93a3e47f57600c374c79d Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 15:31:42 -0800 Subject: [PATCH 33/67] update fmt --- benches/perf.rs | 4 ++-- src/bee/stock/call.rs | 4 ++-- src/channel.rs | 8 ++++---- src/hive/builder/mod.rs | 2 +- src/hive/husk.rs | 4 ++-- src/hive/inner/config.rs | 2 +- src/hive/outcome/store.rs | 6 ++++-- src/util.rs | 6 +----- 8 files changed, 17 insertions(+), 19 deletions(-) diff --git a/benches/perf.rs b/benches/perf.rs index 3dbda77..2088bc8 100644 --- a/benches/perf.rs +++ b/benches/perf.rs @@ -1,6 +1,6 @@ use beekeeper::bee::stock::EchoWorker; -use beekeeper::hive::{outcome_channel, Builder, ChannelBuilder, TaskQueuesBuilder}; -use divan::{bench, black_box_drop, AllocProfiler, Bencher}; +use beekeeper::hive::{Builder, ChannelBuilder, TaskQueuesBuilder, outcome_channel}; +use divan::{AllocProfiler, Bencher, bench, black_box_drop}; use itertools::iproduct; #[global_allocator] diff --git a/src/bee/stock/call.rs b/src/bee/stock/call.rs index a78dc18..daa5890 100644 --- a/src/bee/stock/call.rs +++ b/src/bee/stock/call.rs @@ -270,8 +270,8 @@ mod tests { u8, String, impl FnMut((bool, u8), &Context<(bool, u8)>) -> Result> - + Clone - + 'static, + + Clone + + 'static, > { RetryCaller::of(|input: (bool, u8), _: &Context<(bool, u8)>| { if input.0 { diff --git a/src/channel.rs b/src/channel.rs index 0d195e1..32ae32f 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -30,7 +30,7 @@ pub trait ReceiverExt { #[cfg(not(any(feature = "crossbeam", feature = "flume", feature = "loole")))] pub mod prelude { - pub use std::sync::mpsc::{channel, Receiver, SendError, Sender}; + pub use std::sync::mpsc::{Receiver, SendError, Sender, channel}; use super::{Message, ReceiverExt, SenderExt}; use std::sync::mpsc::TryRecvError; @@ -57,7 +57,7 @@ pub mod prelude { #[cfg(all(feature = "crossbeam", not(any(feature = "flume", feature = "loole"))))] pub mod prelude { - pub use crossbeam_channel::{unbounded as channel, Receiver, SendError, Sender}; + pub use crossbeam_channel::{Receiver, SendError, Sender, unbounded as channel}; use super::{Message, ReceiverExt, SenderExt}; use crossbeam_channel::TryRecvError; @@ -84,7 +84,7 @@ pub mod prelude { #[cfg(all(feature = "flume", not(any(feature = "crossbeam", feature = "loole"))))] pub mod prelude { - pub use flume::{unbounded as channel, Receiver, SendError, Sender}; + pub use flume::{Receiver, SendError, Sender, unbounded as channel}; use super::{Message, ReceiverExt, SenderExt}; use flume::TryRecvError; @@ -111,7 +111,7 @@ pub mod prelude { #[cfg(all(feature = "loole", not(any(feature = "crossbeam", feature = "flume"))))] pub mod prelude { - pub use loole::{unbounded as channel, Receiver, SendError, Sender}; + pub use loole::{Receiver, SendError, Sender, unbounded as channel}; use super::{Message, ReceiverExt, SenderExt}; use loole::TryRecvError; diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs index e211ed8..2bb906c 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -47,9 +47,9 @@ mod queue; pub use bee::BeeBuilder; pub use full::FullBuilder; pub use open::OpenBuilder; +pub use queue::TaskQueuesBuilder; pub use queue::channel::ChannelBuilder; pub use queue::workstealing::WorkstealingBuilder; -pub use queue::TaskQueuesBuilder; pub fn 
open(with_defaults: bool) -> OpenBuilder { if with_defaults { diff --git a/src/hive/husk.rs b/src/hive/husk.rs index 6051e67..cda6d77 100644 --- a/src/hive/husk.rs +++ b/src/hive/husk.rs @@ -139,8 +139,8 @@ mod tests { use crate::bee::stock::{PunkWorker, Thunk, ThunkWorker}; use crate::hive::ChannelTaskQueues; use crate::hive::{ - outcome_channel, Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, OutcomeStore, - TaskQueuesBuilder, + Builder, ChannelBuilder, Outcome, OutcomeIteratorExt, OutcomeStore, TaskQueuesBuilder, + outcome_channel, }; #[test] diff --git a/src/hive/inner/config.rs b/src/hive/inner/config.rs index 17bfb75..7d4f0b8 100644 --- a/src/hive/inner/config.rs +++ b/src/hive/inner/config.rs @@ -118,8 +118,8 @@ pub mod reset { #[cfg(test)] mod tests { - use super::reset::Reset; use super::Config; + use super::reset::Reset; use serial_test::serial; #[test] diff --git a/src/hive/outcome/store.rs b/src/hive/outcome/store.rs index d401ed5..30a3ab4 100644 --- a/src/hive/outcome/store.rs +++ b/src/hive/outcome/store.rs @@ -53,7 +53,9 @@ pub trait OutcomeStore: DerefOutcomes { fn assert_empty(&self, allow_successes: bool) { let (unprocessed, successes, failures) = self.count(); if !allow_successes && successes > 0 { - panic!("{unprocessed} unprocessed inputs, {successes} successes, and {failures} failed tasks found"); + panic!( + "{unprocessed} unprocessed inputs, {successes} successes, and {failures} failed tasks found" + ); } else if unprocessed > 0 || failures > 0 { panic!("{unprocessed} unprocessed inputs and {failures} failed tasks found"); } @@ -523,8 +525,8 @@ mod tests { #[cfg(all(test, feature = "retry"))] mod retry_tests { - use super::tests::TestWorker; use super::OutcomeStore; + use super::tests::TestWorker; use crate::hive::{Outcome, OutcomeBatch}; use crate::panic::Panic; diff --git a/src/util.rs b/src/util.rs index 20d057c..d8793b7 100644 --- a/src/util.rs +++ b/src/util.rs @@ -92,11 +92,7 @@ mod tests { 4, 0..100, |i| { - if i == 50 { - Err("Fiddy!") - } else { - Ok(i + 1) - } + if i == 50 { Err("Fiddy!") } else { Ok(i + 1) } }, ); assert!(result.has_failures()); From f3790921df7da7133f27de7e07348addb810b8b1 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 15:50:54 -0800 Subject: [PATCH 34/67] remove captures that are no longer necessary w rust 2024 --- src/bee/queen.rs | 2 +- src/hive/cores.rs | 2 +- src/hive/inner/queue/channel.rs | 2 +- src/hive/inner/shared.rs | 2 +- src/hive/outcome/queue.rs | 2 +- src/hive/outcome/store.rs | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/bee/queen.rs b/src/bee/queen.rs index 8f9d808..8405086 100644 --- a/src/bee/queen.rs +++ b/src/bee/queen.rs @@ -32,7 +32,7 @@ impl QueenCell { Self(RwLock::new(mut_queen)) } - pub fn get(&self) -> impl Deref + '_ { + pub fn get(&self) -> impl Deref { self.0.read() } diff --git a/src/hive/cores.rs b/src/hive/cores.rs index 77b7f17..a84a485 100644 --- a/src/hive/cores.rs +++ b/src/hive/cores.rs @@ -132,7 +132,7 @@ impl Cores { /// Returns an iterator over `(core_index, Option)`, where `Some(core)` can be used to /// set the core affinity of the current thread. The `core` will be `None` for cores that are /// not currently available. 
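The `+ '_` removals in the following hunks rely on the Rust 2024 edition rule that an `impl Trait` return type implicitly captures all in-scope lifetimes; under the 2021 edition, the bound had to be written out whenever the hidden type borrowed `self` without a lifetime appearing in the return type. A minimal sketch of the rule (illustration only; `Guarded` is a hypothetical type, not part of this patch):

```rust
use std::ops::Deref;

struct Guarded<T>(T);

impl<T> Guarded<T> {
    // Rust 2021 required `-> impl Deref<Target = T> + '_` here, because the
    // returned `&T` borrows `self` while no lifetime appears in the bounds;
    // Rust 2024 captures in-scope lifetimes implicitly, so `+ '_` is redundant.
    fn get(&self) -> impl Deref<Target = T> {
        &self.0
    }
}

fn main() {
    let guarded = Guarded(42);
    assert_eq!(*guarded.get(), 42);
}
```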
-    pub fn iter(&self) -> impl Iterator<Item = (usize, Option<Core>)> + '_ {
+    pub fn iter(&self) -> impl Iterator<Item = (usize, Option<Core>)> {
         let cores = CORES.lock();
         self.0.iter().cloned().map(move |index| {
             (
diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs
index d8eef27..9e9c9c9 100644
--- a/src/hive/inner/queue/channel.rs
+++ b/src/hive/inner/queue/channel.rs
@@ -128,7 +128,7 @@ impl<W: Worker> GlobalQueue<W> {
     }

     #[cfg(feature = "batching")]
-    fn try_iter(&self) -> impl Iterator<Item = Task<W>> + '_ {
+    fn try_iter(&self) -> impl Iterator<Item = Task<W>> {
         self.global_rx.try_iter()
     }
 }
diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs
index 1beeca9..77cc27d 100644
--- a/src/hive/inner/shared.rs
+++ b/src/hive/inner/shared.rs
@@ -458,7 +458,7 @@ impl<W: Worker, Q: Queen<Kind = W>, T: TaskQueues<W>> Shared<W, Q, T> {
     }

     /// Returns a mutable reference to the retained task outcomes.
-    pub fn outcomes(&self) -> impl DerefMut<Target = HashMap<TaskId, Outcome<W>>> + '_ {
+    pub fn outcomes(&self) -> impl DerefMut<Target = HashMap<TaskId, Outcome<W>>> {
         self.outcomes.get_mut()
     }

diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs
index 6d3c83b..a571d2d 100644
--- a/src/hive/outcome/queue.rs
+++ b/src/hive/outcome/queue.rs
@@ -17,7 +17,7 @@ impl<W: Worker> OutcomeQueue<W> {
     }

     /// Flushes the queue into the map of outcomes and returns a mutable reference to the map.
-    pub fn get_mut(&self) -> impl DerefMut<Target = HashMap<TaskId, Outcome<W>>> + '_ {
+    pub fn get_mut(&self) -> impl DerefMut<Target = HashMap<TaskId, Outcome<W>>> {
         let mut outcomes = self.outcomes.lock();
         // add any queued outcomes to the map
         while let Some(outcome) = self.queue.pop() {
diff --git a/src/hive/outcome/store.rs b/src/hive/outcome/store.rs
index 30a3ab4..1858653 100644
--- a/src/hive/outcome/store.rs
+++ b/src/hive/outcome/store.rs
@@ -10,7 +10,7 @@ pub trait DerefOutcomes<W: Worker> {
     fn outcomes_deref(&self) -> impl Deref<Target = HashMap<TaskId, Outcome<W>>>;

     /// Returns a mutable reference to a map of task ID to `Outcome`.
-    fn outcomes_deref_mut(&mut self) -> impl DerefMut<Target = HashMap<TaskId, Outcome<W>>> + '_;
+    fn outcomes_deref_mut(&mut self) -> impl DerefMut<Target = HashMap<TaskId, Outcome<W>>>;
 }

 pub trait OwnedOutcomes<W: Worker>: Sized {

From 2e586fa15a7dd58cce5c43d5e4804f0a3046aace Mon Sep 17 00:00:00 2001
From: jdidion
Date: Sat, 22 Feb 2025 18:27:48 -0800
Subject: [PATCH 35/67] fix tests

---
 src/hive/inner/mod.rs |  3 +++
 src/hive/mod.rs       | 13 +++++--------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs
index f3f4686..eb2c41a 100644
--- a/src/hive/inner/mod.rs
+++ b/src/hive/inner/mod.rs
@@ -16,6 +16,9 @@ pub mod set_config {
     };
 }

+// Note: it would be more appropriate for the publicly exported traits (`Builder`, `TaskQueues`)
+// to be in the `beekeeper::hive` module, but they need to be in `inner` for visibility reasons.
+
 pub use self::builder::{Builder, BuilderConfig};
 pub use self::queue::{ChannelTaskQueues, TaskQueues, WorkerQueues, WorkstealingTaskQueues};

diff --git a/src/hive/mod.rs b/src/hive/mod.rs
index 2c2942d..8cbbf3c 100644
--- a/src/hive/mod.rs
+++ b/src/hive/mod.rs
@@ -1650,17 +1650,14 @@ mod tests {
     /// changed and will wake, even if new tasks have been added in the meantime.
     ///
     /// In this example, this means the waiting threads will exit the join in groups of four
-    /// because the waiter pool has four processes.
-    fn test_join_wavesurfer<B, F>(
-        #[values(channel_builder, workstealing_builder)] builder_factory: F,
-    ) where
-        B: TaskQueuesBuilder,
-        F: Fn(bool) -> B,
-    {
+    /// because the waiter pool has four processes.
+    ///
+    /// TODO: make this test work with WorkstealingTaskQueues.
+ fn test_join_wavesurfer() { let n_waves = 4; let n_workers = 4; let (tx, rx) = mpsc::channel(); - let builder = builder_factory(false) + let builder = channel_builder(false) .num_threads(n_workers) .thread_name("join wavesurfer"); let waiter_hive = builder From 9c4b082f0727a87615a79573e1ea2d6cc88e05f0 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 18:36:44 -0800 Subject: [PATCH 36/67] fix tests --- src/hive/mod.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 8cbbf3c..8433d96 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -1099,13 +1099,10 @@ mod tests { assert_eq!(outputs, (0..8).rev().collect::>()) } + // TODO: make this test work with WorkstealingTaskQueues #[rstest] - fn test_map_send(#[values(channel_builder, workstealing_builder)] builder_factory: F) - where - B: TaskQueuesBuilder, - F: Fn(bool) -> B, - { - let hive = thunk_hive::(8, builder_factory(false)); + fn test_map_send() { + let hive = thunk_hive::(8, channel_builder(false)); let (tx, rx) = super::outcome_channel(); let mut task_ids = hive.map_send( (0..8u8).map(|i| { From 658cd31a1295bbdc9292d04cfd55ff34561bb021 Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 18:50:39 -0800 Subject: [PATCH 37/67] fix tests --- src/hive/mod.rs | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 8433d96..7dc9b18 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -1099,10 +1099,13 @@ mod tests { assert_eq!(outputs, (0..8).rev().collect::>()) } - // TODO: make this test work with WorkstealingTaskQueues #[rstest] - fn test_map_send() { - let hive = thunk_hive::(8, channel_builder(false)); + fn test_map_send(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(8, builder_factory(false)); let (tx, rx) = super::outcome_channel(); let mut task_ids = hive.map_send( (0..8u8).map(|i| { @@ -1113,17 +1116,18 @@ mod tests { }), tx, ); - let (mut outcome_task_ids, values): (Vec, Vec) = rx + let (mut outcome_task_ids, mut values): (Vec, Vec) = rx .iter() .map(|outcome| match outcome { Outcome::Success { value, task_id } => (task_id, value), _ => panic!("unexpected error"), }) .unzip(); - assert_eq!(values, (0..8).rev().collect::>()); task_ids.sort(); outcome_task_ids.sort(); assert_eq!(task_ids, outcome_task_ids); + values.sort(); + assert_eq!(values, (0..8).collect::>()); } #[rstest] @@ -1181,7 +1185,7 @@ mod tests { F: Fn(bool) -> B, { let hive = thunk_hive::(8, builder_factory(false)); - let outputs: Vec<_> = hive + let mut outputs: Vec<_> = hive .swarm_unordered((0..8u8).map(|i| { Thunk::of(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); @@ -1190,7 +1194,8 @@ mod tests { })) .map(Outcome::unwrap) .collect(); - assert_eq!(outputs, (0..8).rev().collect::>()) + outputs.sort(); + assert_eq!(outputs, (0..8).collect::>()) } #[rstest] From d3ec064040b3eb1e18eae61fb5ab49649b836f6e Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 19:02:14 -0800 Subject: [PATCH 38/67] fix tests --- src/hive/mod.rs | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 7dc9b18..658d975 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -2117,18 +2117,14 @@ mod batching_tests { run_test(&hive, NUM_THREADS, BATCH_LIMIT_2, false); } + // TODO: make this work with WorkstealingTaskQueues #[rstest] - fn 
test_shrink_batch_limit( - #[values(channel_builder, workstealing_builder)] builder_factory: F, - ) where - B: TaskQueuesBuilder, - F: Fn(bool) -> B, - { + fn test_shrink_batch_limit() { const NUM_THREADS: usize = 4; const NUM_TASKS_PER_THREAD: usize = 125; const BATCH_LIMIT_0: usize = 100; const BATCH_LIMIT_1: usize = 10; - let hive = builder_factory(false) + let hive = channel_builder(false) .with_worker_default() .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT_0) From 2f0a57b50a9717d68e47f315607234dcfc6eb14a Mon Sep 17 00:00:00 2001 From: jdidion Date: Sat, 22 Feb 2025 19:07:23 -0800 Subject: [PATCH 39/67] fix tests --- src/hive/mod.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 658d975..be51ba3 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -1087,7 +1087,7 @@ mod tests { F: Fn(bool) -> B, { let hive = thunk_hive::(8, builder_factory(false)); - let outputs: Vec<_> = hive + let mut outputs: Vec<_> = hive .map_unordered((0..8u8).map(|i| { Thunk::of(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); @@ -1096,7 +1096,8 @@ mod tests { })) .map(Outcome::unwrap) .collect(); - assert_eq!(outputs, (0..8).rev().collect::>()) + outputs.sort(); + assert_eq!(outputs, (0..8).collect::>()) } #[rstest] From ff9fce6086a10e216bc42eac5596e6dbbc90f65a Mon Sep 17 00:00:00 2001 From: jdidion Date: Mon, 24 Feb 2025 10:51:21 -0800 Subject: [PATCH 40/67] add tests --- Cargo.toml | 2 +- src/bee/context.rs | 16 +++++--- src/hive/mod.rs | 58 ++++++++++++++++++++++++++++ src/hive/outcome/iter.rs | 27 ++++++++++--- src/hive/outcome/mod.rs | 1 - src/hive/outcome/outcome.rs | 77 ++++++++++++++++++++++++++++++++++++- 6 files changed, 166 insertions(+), 15 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8fba746..db4c6a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,7 +40,7 @@ name = "perf" harness = false [features] -default = [] +default = ["affinity", "batching", "retry"] affinity = ["dep:core_affinity"] batching = [] retry = [] diff --git a/src/bee/context.rs b/src/bee/context.rs index e804bc0..795b1ff 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -1,4 +1,5 @@ //! The context for a task processed by a `Worker`. +use std::cell::RefCell; use std::fmt::Debug; pub type TaskId = usize; @@ -17,7 +18,7 @@ pub trait TaskContext: Debug { pub struct Context<'a, I> { task_id: TaskId, task_ctx: Option<&'a dyn TaskContext>, - subtask_ids: Option>, + subtask_ids: RefCell>>, #[cfg(feature = "retry")] attempt: u32, } @@ -50,10 +51,13 @@ impl Context<'_, I> { /// the `Hive` if there is no sender. /// /// Returns an `Err` containing `input` if the new task was not successfully submitted. - pub fn submit(&mut self, input: I) -> Result<(), I> { + pub fn submit(&self, input: I) -> Result<(), I> { if let Some(worker) = self.task_ctx.as_ref() { let task_id = worker.submit_task(input); - self.subtask_ids.get_or_insert_default().push(task_id); + self.subtask_ids + .borrow_mut() + .get_or_insert_default() + .push(task_id); Ok(()) } else { Err(input) @@ -63,7 +67,7 @@ impl Context<'_, I> { /// Consumes this `Context` and returns the IDs of the subtasks spawned during the execution /// of the task, if any. 
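Switching `subtask_ids` to a `RefCell` above is what lets `submit` take `&self` rather than `&mut self`: the list of spawned subtask IDs is mutated through a shared reference. A minimal sketch of the same interior-mutability pattern (illustration only; `SubtaskLog` is hypothetical):

```rust
use std::cell::RefCell;

struct SubtaskLog {
    ids: RefCell<Option<Vec<usize>>>,
}

impl SubtaskLog {
    // record an ID through a shared reference, creating the Vec on first use
    fn record(&self, id: usize) {
        self.ids.borrow_mut().get_or_insert_default().push(id);
    }

    // consume the log and return the recorded IDs, mirroring `into_subtask_ids`
    fn into_ids(self) -> Option<Vec<usize>> {
        self.ids.into_inner()
    }
}

fn main() {
    let log = SubtaskLog { ids: RefCell::new(None) };
    log.record(1);
    log.record(2);
    assert_eq!(log.into_ids(), Some(vec![1, 2]));
}
```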
pub(crate) fn into_subtask_ids(self) -> Option> { - self.subtask_ids + self.subtask_ids.into_inner() } } @@ -103,7 +107,7 @@ impl<'a, I> Context<'a, I> { task_id: 0, attempt: 0, task_ctx: None, - subtask_ids: None, + subtask_ids: RefCell::new(None), } } @@ -113,7 +117,7 @@ impl<'a, I> Context<'a, I> { task_id, attempt, task_ctx, - subtask_ids: None, + subtask_ids: RefCell::new(None), } } diff --git a/src/hive/mod.rs b/src/hive/mod.rs index be51ba3..02823b7 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -1493,6 +1493,46 @@ mod tests { .collect::>(); } + const NUM_FIRST_TASKS: usize = 4; + + #[derive(Debug, Default)] + struct SendWorker; + + impl Worker for SendWorker { + type Input = usize; + type Output = usize; + type Error = (); + + fn apply(&mut self, input: Self::Input, ctx: &Context) -> WorkerResult { + if input < NUM_FIRST_TASKS { + ctx.submit(input + NUM_FIRST_TASKS) + .map_err(|input| ApplyError::Retryable { input, error: () })?; + } + Ok(input) + } + } + + #[rstest] + fn test_send_from_task( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) + .num_threads(2) + .with_worker_default::() + .build(); + let (tx, rx) = super::outcome_channel(); + let task_ids = hive.map_send(0..NUM_FIRST_TASKS, tx); + hive.join(); + // each task submits another task + assert_eq!(task_ids.len(), NUM_FIRST_TASKS); + let outputs: Vec<_> = rx.select_ordered_outputs(task_ids).collect(); + assert_eq!(outputs.len(), NUM_FIRST_TASKS * 2); + assert_eq!(outputs, (0..NUM_FIRST_TASKS * 2).collect::>()); + } + #[rstest] fn test_husk(#[values(channel_builder, workstealing_builder)] builder_factory: F) where @@ -1622,6 +1662,24 @@ mod tests { ); } + #[rstest] + fn test_clone_into_husk_fails( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive1: Hive>, B::TaskQueues<_>> = builder_factory(false) + .with_worker_default() + .num_threads(2) + .build(); + let hive2 = hive1.clone(); + // should return None the first time since there is more than one reference + assert!(hive1.try_into_husk(false).is_none()); + // hive1 has been dropped, so we're down to 1 reference and it should succeed + assert!(hive2.try_into_husk(false).is_some()); + } + #[rstest] fn test_channel_hive_send() { fn assert_send() {} diff --git a/src/hive/outcome/iter.rs b/src/hive/outcome/iter.rs index 54f9405..47b8c3f 100644 --- a/src/hive/outcome/iter.rs +++ b/src/hive/outcome/iter.rs @@ -37,7 +37,7 @@ impl UnorderedOutcomeIterator { { let task_ids: BTreeSet<_> = task_ids.into_iter().collect(); Self { - inner: Box::new(inner.into_iter().take(task_ids.len())), + inner: Box::new(inner.into_iter()), task_ids, } } @@ -47,19 +47,27 @@ impl Iterator for UnorderedOutcomeIterator { type Item = Outcome; fn next(&mut self) -> Option { + if self.task_ids.is_empty() { + return None; + } loop { match self.inner.next() { Some(outcome) if self.task_ids.remove(outcome.task_id()) => break Some(outcome), - Some(_) => continue, // drop unrequested outcomes - None if !self.task_ids.is_empty() => { + None => { // convert extra task_ids to Missing outcomes break Some(Outcome::Missing { task_id: self.task_ids.pop_first().unwrap(), }); } - None => break None, + _ => continue, // drop unrequested outcomes } } + .inspect(|outcome| { + // if the originating task submitted subtasks, add their IDs to the queue + if let Some(subtask_ids) = outcome.subtask_ids() { + 
self.task_ids.extend(subtask_ids); + } + }) } } @@ -84,7 +92,7 @@ impl OrderedOutcomeIterator { { let task_ids: VecDeque = task_ids.into_iter().collect(); Self { - inner: Box::new(inner.into_iter().take(task_ids.len())), + inner: Box::new(inner.into_iter()), buf: HashMap::with_capacity(task_ids.len()), task_ids, } @@ -122,10 +130,19 @@ impl Iterator for OrderedOutcomeIterator { //if !self.buf.is_empty() { .. } break None; } + .inspect(|outcome| { + // if the originating task submitted subtasks, add their IDs to the queue + if let Some(subtask_ids) = outcome.subtask_ids() { + self.task_ids.extend(subtask_ids); + } + }) } } /// Extension trait for iterators over `Outcome`s. +/// +/// Note that, if your worker submits additional tasks to the `Hive`, their `Outcome`s will be +/// included in the iterator. pub trait OutcomeIteratorExt: IntoIterator> + Sized { /// Consumes this iterator and returns an unordered iterator over the `Outcome`s with the /// specified `task_ids`. diff --git a/src/hive/outcome/mod.rs b/src/hive/outcome/mod.rs index f3d2610..3669e51 100644 --- a/src/hive/outcome/mod.rs +++ b/src/hive/outcome/mod.rs @@ -23,7 +23,6 @@ use crate::panic::Panic; /// /// Note that `Outcome`s can only be compared or ordered with other `Outcome`s produced by the same /// `Hive`, because comparison/ordering is completely based on the task ID. -#[derive(Debug)] pub enum Outcome { /// The task was executed successfully. Success { value: W::Output, task_id: TaskId }, diff --git a/src/hive/outcome/outcome.rs b/src/hive/outcome/outcome.rs index 777430d..39b8886 100644 --- a/src/hive/outcome/outcome.rs +++ b/src/hive/outcome/outcome.rs @@ -1,6 +1,7 @@ use super::Outcome; use crate::bee::{ApplyError, TaskId, Worker, WorkerResult}; use std::cmp::Ordering; +use std::fmt::Debug; impl Outcome { /// Converts a worker `result` into an `Outcome` with the given task_id and optional subtask ids. @@ -130,13 +131,16 @@ impl Outcome { /// Consumes this `Outcome` and returns the value if it is a `Success`, otherwise panics. pub fn unwrap(self) -> W::Output { - self.success().expect("not a Success outcome") + match self { + Self::Success { value, .. } | Self::SuccessWithSubtasks { value, .. } => value, + outcome => panic!("Not a success outcome: {:?}", outcome), + } } /// Consumes this `Outcome` and returns the output value if it is a `Success`, otherwise `None`. pub fn success(self) -> Option { match self { - Self::Success { value, .. } => Some(value), + Self::Success { value, .. } | Self::SuccessWithSubtasks { value, .. } => Some(value), _ => None, } } @@ -178,6 +182,75 @@ impl Outcome { } } +impl Debug for Outcome { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Success { task_id, .. } => { + f.debug_struct("Success").field("task_id", task_id).finish() + } + Self::SuccessWithSubtasks { + task_id, + subtask_ids, + .. + } => f + .debug_struct("SuccessWithSubtasks") + .field("task_id", task_id) + .field("subtask_ids", subtask_ids) + .finish(), + Self::Failure { error, task_id, .. } => f + .debug_struct("Failure") + .field("error", error) + .field("task_id", task_id) + .finish(), + Self::FailureWithSubtasks { + error, + task_id, + subtask_ids, + .. + } => f + .debug_struct("FailureWithSubtasks") + .field("error", error) + .field("task_id", task_id) + .field("subtask_ids", subtask_ids) + .finish(), + Self::Unprocessed { task_id, .. 
} => f + .debug_struct("Unprocessed") + .field("task_id", task_id) + .finish(), + Self::UnprocessedWithSubtasks { + task_id, + subtask_ids, + .. + } => f + .debug_struct("UnprocessedWithSubtasks") + .field("task_id", task_id) + .field("subtask_ids", subtask_ids) + .finish(), + Self::Missing { task_id } => { + f.debug_struct("Missing").field("task_id", task_id).finish() + } + Self::Panic { task_id, .. } => { + f.debug_struct("Panic").field("task_id", task_id).finish() + } + Self::PanicWithSubtasks { + task_id, + subtask_ids, + .. + } => f + .debug_struct("PanicWithSubtasks") + .field("task_id", task_id) + .field("subtask_ids", subtask_ids) + .finish(), + #[cfg(feature = "retry")] + Self::MaxRetriesAttempted { error, task_id, .. } => f + .debug_struct("MaxRetriesAttempted") + .field("error", error) + .field("task_id", task_id) + .finish(), + } + } +} + impl PartialEq for Outcome { fn eq(&self, other: &Self) -> bool { match (self, other) { From 551cd9d0ed3f4f9aca39f1e262323759b17b1c63 Mon Sep 17 00:00:00 2001 From: jdidion Date: Mon, 24 Feb 2025 11:03:40 -0800 Subject: [PATCH 41/67] fix --- src/bee/context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bee/context.rs b/src/bee/context.rs index 795b1ff..45c2417 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -78,7 +78,7 @@ impl<'a, I> Context<'a, I> { Self { task_id: 0, task_ctx: None, - subtask_ids: None, + subtask_ids: RefCell::new(None), } } @@ -87,7 +87,7 @@ impl<'a, I> Context<'a, I> { Self { task_id, task_ctx, - subtask_ids: None, + subtask_ids: RefCell::new(None), } } From 6551f47dcede015fb324cc8435f302e804a1bf7b Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 27 Feb 2025 12:55:23 -0800 Subject: [PATCH 42/67] add mock module; add support for task weighting --- .github/workflows/ci.yml | 8 +- CHANGELOG.md | 5 +- Cargo.toml | 14 +- README.md | 2 +- src/atomic.rs | 4 +- src/bee/context.rs | 157 ++++++------ src/bee/mock.rs | 44 ++++ src/bee/mod.rs | 7 +- src/bee/queen.rs | 3 + src/hive/builder/mod.rs | 4 +- src/hive/context.rs | 63 +++++ src/hive/hive.rs | 341 ++++++++------------------- src/hive/inner/builder.rs | 40 +++- src/hive/inner/config.rs | 28 ++- src/hive/inner/mod.rs | 16 +- src/hive/inner/queue/channel.rs | 115 +++++---- src/hive/inner/queue/retry.rs | 26 +- src/hive/inner/queue/workstealing.rs | 76 ++++-- src/hive/inner/shared.rs | 113 ++++++--- src/hive/inner/task.rs | 140 +++++++---- src/hive/mod.rs | 87 +++---- src/hive/outcome/mod.rs | 7 + src/hive/outcome/outcome.rs | 28 ++- src/hive/sentinel.rs | 67 ++++++ src/hive/util.rs | 33 +++ src/hive/weighted.rs | 172 ++++++++++++++ 26 files changed, 1026 insertions(+), 574 deletions(-) create mode 100644 src/bee/mock.rs create mode 100644 src/hive/context.rs create mode 100644 src/hive/sentinel.rs create mode 100644 src/hive/util.rs create mode 100644 src/hive/weighted.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 322bdc3..9331dab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,11 +20,11 @@ jobs: with: components: rustfmt, clippy - run: | - cargo clippy --all-targets -F affinity,batching,retry \ + cargo clippy --all-targets -F affinity,local-batch,retry \ -- -D warnings $(cat .lints | cut -f1 -d"#" | tr '\n' ' ') - run: cargo fmt -- --check - - run: cargo doc -F affinity,batching,retry - - run: cargo test -F affinity,batching,retry --doc + - run: cargo doc -F affinity,local-batch,retry + - run: cargo test -F affinity,local-batch,retry --doc coverage: name: Code coverage @@ -35,7 
+35,7 @@ jobs:
       - name: Install cargo-llvm-cov
         uses: taiki-e/install-action@cargo-llvm-cov
       - name: Generate code coverage
-        run: cargo llvm-cov --lcov --output-path lcov.info -F affinity,batching,retry
+        run: cargo llvm-cov --lcov --output-path lcov.info -F affinity,local-batch,retry
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3
         with:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2b25839..1d41e62 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,12 +20,15 @@ The general theme of this release is performance improvement by eliminating thre
 * Added the `TaskQueues` trait, which enables `Hive` to be specialized for different implementations of global (i.e., sending tasks from the `Hive` to worker threads) and local (i.e., worker thread-specific) queues.
   * `ChannelTaskQueues` implements the existing behavior, using a channel for sending tasks.
   * `WorkstealingTaskQueues` has been added to implement the workstealing pattern, based on `crossbeam::deque`.
-  * Added the `batching` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks.
+  * Added the `local-batch` feature, which enables worker threads to queue up batches of tasks locally, which can alleviate contention between threads in the pool, especially when there are many short-lived tasks.
+    * When this feature is enabled, tasks can be optionally weighted (by wrapping each input in `crate::hive::Weighted`) to help evenly distribute tasks with variable processing times.
+    * Enabling this feature should be transparent (i.e., not break existing code), and the `Hive`'s task submission methods support both weighted and unweighted inputs (due to the blanket implementation of `From<T> for Weighted<T>`); however, there are some cases where it is now necessary to specify the input type where before it could be elided.
 * Added the `Context::submit` method, which enables tasks to submit new tasks to the `Hive`.
 * Other
   * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread contention.
   * Switched to storing `Outcome`s in the hive using a data structure that does not require locking when inserting, which should reduce thread contention when using `*_store` operations.
   * Switched to using `crossbeam_channel` for the task input channel in `ChannelTaskQueues`. These are multi-producer, multi-consumer channels (mpmc; as opposed to `std::mpsc`, which is single-consumer), which means it is no longer necessary for worker threads to acquire a `Mutex` lock on the channel receiver when getting tasks.
+  * Added the `beekeeper::bee::mock` module, which has a mock implementation of `beekeeper::bee::context::LocalContext` and an `apply` function for applying a worker in a mock context. This is useful for testing your `Worker`.
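A minimal sketch of testing a `Worker` with the new mock module (illustration only; it assumes `EchoWorker` implements `Default`):

```rust
use beekeeper::bee::stock::EchoWorker;
use beekeeper::bee::{mock, TaskMeta};

fn main() {
    let mut worker = EchoWorker::<usize>::default();
    // run the worker outside of a `Hive`, against a mocked local context
    let (result, _meta, subtask_ids) = mock::apply(42, TaskMeta::new(0), &mut worker);
    assert_eq!(result.unwrap(), 42);
    // the echo worker never submits subtasks
    assert!(subtask_ids.is_none());
}
```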
## 0.2.1 diff --git a/Cargo.toml b/Cargo.toml index db4c6a5..c3bdb43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,12 +24,13 @@ thiserror = "1.0.63" # required with the `affinity` feature core_affinity = { version = "0.8.1", optional = true } # required with alternate outcome channel implementations that can be enabled with features +# NOTE: these version requirements could be relaxed as we don't actually depend on the +# functionality of these crates internally (other than in tests) flume = { version = "0.11.1", optional = true } loole = { version = "0.4.0", optional = true } [dev-dependencies] divan = "0.1.17" -generic-tests = "0.1.3" itertools = "0.14.0" serial_test = "3.2.0" rstest = "0.22.0" @@ -40,16 +41,21 @@ name = "perf" harness = false [features] -default = ["affinity", "batching", "retry"] +default = ["affinity", "local-batch", "retry"] affinity = ["dep:core_affinity"] -batching = [] +local-batch = [] retry = [] crossbeam = [] flume = ["dep:flume"] loole = ["dep:loole"] +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = [ + 'cfg(coverage,coverage_nightly)', +] } + [package.metadata.cargo-all-features] -allowlist = ["affinity", "batching", "retry"] +allowlist = ["affinity", "local-batch", "retry"] [profile.release] lto = true diff --git a/README.md b/README.md index ea26646..534f21c 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ is sometimes called a "worker pool"). * The following optional features are provided via feature flags: * `affinity`: worker threads may be pinned to CPU cores to minimize the overhead of context-switching. - * `batching` (>=0.3.0): worker threads take batches of tasks from the input channel and queue them locally, which may alleviate thread contention, especially when there are many short-lived tasks. + * `local-batch` (>=0.3.0): worker threads take batches of tasks from the input channel and queue them locally, which may alleviate thread contention, especially when there are many short-lived tasks. * `retry`: Tasks that fail due to transient errors (e.g., temporarily unavailable resources) may be retried a set number of times, with an optional, exponentially increasing delay between retries. diff --git a/src/atomic.rs b/src/atomic.rs index 6452717..ef0e72c 100644 --- a/src/atomic.rs +++ b/src/atomic.rs @@ -389,8 +389,8 @@ mod affinity { } } -#[cfg(any(feature = "batching", feature = "retry"))] -mod batching { +#[cfg(any(feature = "local-batch", feature = "retry"))] +mod local_batch { use super::{Atomic, AtomicOption, MutError}; use std::fmt::Debug; diff --git a/src/bee/context.rs b/src/bee/context.rs index 45c2417..0150ecc 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -1,12 +1,14 @@ //! The context for a task processed by a `Worker`. + use std::cell::RefCell; use std::fmt::Debug; +/// Type of unique ID for a task within the `Hive`. pub type TaskId = usize; /// Trait that provides a `Context` with limited access to a worker thread's state during /// task execution. -pub trait TaskContext: Debug { +pub trait LocalContext: Debug { /// Returns `true` if tasks in progress should be cancelled. fn should_cancel_tasks(&self) -> bool; @@ -14,19 +16,41 @@ pub trait TaskContext: Debug { fn submit_task(&self, input: I) -> TaskId; } +/// The context visible to a task when processing an input. 
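/// A minimal sketch of how a long-running worker might use its context
/// (illustration only; the exact shape of `ApplyError::Cancelled` is assumed):
///
/// ```ignore
/// fn apply(&mut self, input: usize, ctx: &Context<usize>) -> WorkerResult<Self> {
///     // check for cooperative cancellation before doing expensive work
///     if ctx.is_cancelled() {
///         return Err(ApplyError::Cancelled { input });
///     }
///     // fan out a follow-up task to the same hive; its ID is recorded as a
///     // subtask of this task
///     let _ = ctx.submit(input + 1);
///     Ok(input)
/// }
/// ```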
#[derive(Debug)] pub struct Context<'a, I> { - task_id: TaskId, - task_ctx: Option<&'a dyn TaskContext>, + meta: TaskMeta, + local: Option<&'a dyn LocalContext>, subtask_ids: RefCell>>, - #[cfg(feature = "retry")] - attempt: u32, } -impl Context<'_, I> { +impl<'a, I> Context<'a, I> { + /// Returns a new empty context. This is primarily useful for testing. + pub fn empty() -> Self { + Self { + meta: TaskMeta::empty(), + local: None, + subtask_ids: RefCell::new(None), + } + } + + /// Creates a new `Context` with the given task_id and shared cancellation status. + pub fn new(meta: TaskMeta, local: Option<&'a dyn LocalContext>) -> Self { + Self { + meta, + local, + subtask_ids: RefCell::new(None), + } + } + /// The unique ID of this task within the `Hive`. pub fn task_id(&self) -> TaskId { - self.task_id + self.meta.id + } + + /// Returns the number of previous failed attempts to execute the current task. + pub fn attempt(&self) -> u32 { + self.meta.attempt } /// Returns `true` if the current task should be cancelled. @@ -34,7 +58,7 @@ impl Context<'_, I> { /// A long-running `Worker` should check this periodically and, if it returns `true`, exit /// early with an `ApplyError::Cancelled` result. pub fn is_cancelled(&self) -> bool { - self.task_ctx + self.local .as_ref() .map(|worker| worker.should_cancel_tasks()) .unwrap_or(false) @@ -52,8 +76,8 @@ impl Context<'_, I> { /// /// Returns an `Err` containing `input` if the new task was not successfully submitted. pub fn submit(&self, input: I) -> Result<(), I> { - if let Some(worker) = self.task_ctx.as_ref() { - let task_id = worker.submit_task(input); + if let Some(local) = self.local.as_ref() { + let task_id = local.submit_task(input); self.subtask_ids .borrow_mut() .get_or_insert_default() @@ -66,87 +90,80 @@ impl Context<'_, I> { /// Consumes this `Context` and returns the IDs of the subtasks spawned during the execution /// of the task, if any. - pub(crate) fn into_subtask_ids(self) -> Option> { - self.subtask_ids.into_inner() + pub(crate) fn into_parts(self) -> (TaskMeta, Option>) { + (self.meta, self.subtask_ids.into_inner()) } } -#[cfg(not(feature = "retry"))] -impl<'a, I> Context<'a, I> { - /// Returns a new empty context. This is primarily useful for testing. +/// The metadata of a task. +#[derive(Default, Clone, Debug)] +pub struct TaskMeta { + id: TaskId, + #[cfg(feature = "local-batch")] + weight: u32, + #[cfg(feature = "retry")] + attempt: u32, +} + +impl TaskMeta { pub fn empty() -> Self { - Self { - task_id: 0, - task_ctx: None, - subtask_ids: RefCell::new(None), + Self::new(0) + } + + pub fn new(id: TaskId) -> Self { + TaskMeta { + id, + ..Default::default() } } - /// Creates a new `Context` with the given task_id and shared cancellation status. - pub fn new(task_id: TaskId, task_ctx: Option<&'a dyn TaskContext>) -> Self { - Self { - task_id, - task_ctx, - subtask_ids: RefCell::new(None), + #[cfg(feature = "local-batch")] + pub fn with_weight(task_id: TaskId, weight: u32) -> Self { + TaskMeta { + id: task_id, + weight, + ..Default::default() } } + pub fn id(&self) -> TaskId { + self.id + } + /// The number of previous failed attempts to execute the current task. /// - /// Always returns `0`. + /// Always returns `0` if the `retry` feature is not enabled. pub fn attempt(&self) -> u32 { - 0 - } -} - -#[cfg(feature = "retry")] -impl<'a, I> Context<'a, I> { - /// Returns a new empty context. This is primarily useful for testing. 
- pub fn empty() -> Self { - Self { - task_id: 0, - attempt: 0, - task_ctx: None, - subtask_ids: RefCell::new(None), - } + #[cfg(feature = "retry")] + return self.attempt; + #[cfg(not(feature = "retry"))] + return 0; } - /// Creates a new `Context` with the given task_id and shared cancellation status. - pub fn new(task_id: TaskId, attempt: u32, task_ctx: Option<&'a dyn TaskContext>) -> Self { - Self { - task_id, - attempt, - task_ctx, - subtask_ids: RefCell::new(None), - } + /// Increments the number of previous failed attempts to execute the current task. + #[cfg(feature = "retry")] + pub(crate) fn inc_attempt(&mut self) { + self.attempt += 1; } - /// The number of previous attempts to execute the current task. + /// Returns the task weight. /// - /// Returns `0` for the first attempt and increments by `1` for each retry attempt (if any). - pub fn attempt(&self) -> u32 { - self.attempt + /// Always returns `0` if the `local-batch` feature is not enabled. + pub fn weight(&self) -> u32 { + #[cfg(feature = "local-batch")] + return self.weight; + #[cfg(not(feature = "local-batch"))] + return 0; } } -#[cfg(test)] -pub mod mock { - use super::{TaskContext, TaskId}; - use std::cell::RefCell; - - #[derive(Debug, Default)] - pub struct MockTaskContext(RefCell); - - impl TaskContext for MockTaskContext { - fn should_cancel_tasks(&self) -> bool { - false - } - - fn submit_task(&self, _: I) -> super::TaskId { - let mut task_id = self.0.borrow_mut(); - let cur_id = *task_id; - *task_id += 1; - cur_id +#[cfg(all(test, feature = "retry"))] +impl TaskMeta { + pub fn with_attempt(task_id: TaskId, attempt: u32) -> Self { + Self { + id: task_id, + attempt, + ..Default::default() } } } diff --git a/src/bee/mock.rs b/src/bee/mock.rs new file mode 100644 index 0000000..c73b91f --- /dev/null +++ b/src/bee/mock.rs @@ -0,0 +1,44 @@ +use super::{Context, LocalContext, TaskId, TaskMeta, Worker, WorkerResult}; +use std::cell::RefCell; + +/// Applies the given `worker` to the given `input` using the given `task_meta`. +/// +/// Returns a tuple of the apply result, the (possibly modified) task metadata, and the IDs of any +/// subtasks that were submitted. +pub fn apply( + input: W::Input, + task_meta: TaskMeta, + worker: &mut W, +) -> (WorkerResult, TaskMeta, Option>) { + let local = MockLocalContext::new(task_meta.id()); + let ctx = Context::new(task_meta, Some(&local)); + let result = worker.apply(input, &ctx); + let (task_meta, subtask_ids) = ctx.into_parts(); + (result, task_meta, subtask_ids) +} + +#[derive(Debug, Default)] +pub struct MockLocalContext(RefCell); + +impl MockLocalContext { + pub fn new(task_id: TaskId) -> Self { + Self(RefCell::new(task_id)) + } + + pub fn into_task_count(self) -> usize { + self.0.into_inner() + } +} + +impl LocalContext for MockLocalContext { + fn should_cancel_tasks(&self) -> bool { + false + } + + fn submit_task(&self, _: I) -> super::TaskId { + let mut task_id = self.0.borrow_mut(); + let cur_id = *task_id; + *task_id += 1; + cur_id + } +} diff --git a/src/bee/mod.rs b/src/bee/mod.rs index 7b1027d..de3938a 100644 --- a/src/bee/mod.rs +++ b/src/bee/mod.rs @@ -1,3 +1,4 @@ +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] //! Traits for defining workers in the worker pool. //! //! A [`Hive`](crate::hive::Hive) is populated by bees: @@ -111,13 +112,13 @@ //! workers, the queen, and/or the client thread(s). 
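//! A minimal sketch of a custom worker (illustration only; `DoubleWorker` is
//! hypothetical):
//!
//! ```ignore
//! use beekeeper::bee::{Context, Worker, WorkerResult};
//!
//! #[derive(Debug, Default)]
//! struct DoubleWorker;
//!
//! impl Worker for DoubleWorker {
//!     type Input = u32;
//!     type Output = u32;
//!     type Error = ();
//!
//!     fn apply(&mut self, input: u32, _ctx: &Context<u32>) -> WorkerResult<Self> {
//!         Ok(input * 2)
//!     }
//! }
//! ```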
mod context; mod error; +#[cfg_attr(coverage_nightly, coverage(off))] +pub mod mock; mod queen; pub mod stock; mod worker; -#[cfg(test)] -pub use self::context::mock::MockTaskContext; -pub use self::context::{Context, TaskContext, TaskId}; +pub use self::context::{Context, LocalContext, TaskId, TaskMeta}; pub use self::error::{ApplyError, ApplyRefError}; pub use self::queen::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut}; pub use self::worker::{RefWorker, RefWorkerResult, Worker, WorkerError, WorkerResult}; diff --git a/src/bee/queen.rs b/src/bee/queen.rs index 8405086..0f670f8 100644 --- a/src/bee/queen.rs +++ b/src/bee/queen.rs @@ -122,3 +122,6 @@ impl Queen for CloneQueen { self.0.clone() } } + +#[cfg(test)] +mod tests {} diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs index 2bb906c..7c8e790 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -90,8 +90,8 @@ use crate::hive::inner::{BuilderConfig, Token}; // } // } -// #[cfg(all(test, feature = "batching"))] -// mod batching_tests { +// #[cfg(all(test, feature = "local-batch"))] +// mod local_batch_tests { // use super::OpenBuilder; // } diff --git a/src/hive/context.rs b/src/hive/context.rs new file mode 100644 index 0000000..a958da2 --- /dev/null +++ b/src/hive/context.rs @@ -0,0 +1,63 @@ +use crate::bee::{LocalContext, Queen, TaskId, Worker}; +use crate::hive::{OutcomeSender, Shared, TaskQueues, WorkerQueues}; +use std::fmt; +use std::sync::Arc; + +pub struct HiveLocalContext<'a, W, Q, T> +where + W: Worker, + Q: Queen, + T: TaskQueues, +{ + worker_queues: &'a T::WorkerQueues, + shared: &'a Arc>, + outcome_tx: Option<&'a OutcomeSender>, +} + +impl<'a, W, Q, T> HiveLocalContext<'a, W, Q, T> +where + W: Worker, + Q: Queen, + T: TaskQueues, +{ + pub fn new( + worker_queues: &'a T::WorkerQueues, + shared: &'a Arc>, + outcome_tx: Option<&'a OutcomeSender>, + ) -> Self { + Self { + worker_queues, + shared, + outcome_tx, + } + } +} + +impl LocalContext for HiveLocalContext<'_, W, Q, T> +where + W: Worker, + Q: Queen, + T: TaskQueues, +{ + fn should_cancel_tasks(&self) -> bool { + self.shared.is_suspended() + } + + fn submit_task(&self, input: W::Input) -> TaskId { + let task = self.shared.prepare_task(input, self.outcome_tx); + let task_id = task.id(); + self.worker_queues.push(task); + task_id + } +} + +impl fmt::Debug for HiveLocalContext<'_, W, Q, T> +where + W: Worker, + Q: Queen, + T: TaskQueues, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("HiveLocalContext").finish() + } +} diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 27e4673..488659b 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -1,15 +1,15 @@ use super::{ - ChannelBuilder, ChannelTaskQueues, Config, DerefOutcomes, Husk, Outcome, OutcomeBatch, - OutcomeIteratorExt, OutcomeSender, Shared, SpawnError, TaskQueues, TaskQueuesBuilder, - WorkerQueues, + ChannelBuilder, ChannelTaskQueues, Config, DerefOutcomes, HiveLocalContext, Husk, Outcome, + OutcomeBatch, OutcomeIteratorExt, OutcomeSender, Sentinel, Shared, SpawnError, TaskInput, + TaskQueues, TaskQueuesBuilder, }; -use crate::bee::{DefaultQueen, Queen, TaskContext, TaskId, Worker}; +use crate::bee::{ApplyError, Context, DefaultQueen, Queen, TaskId, Worker}; use std::borrow::Borrow; use std::collections::HashMap; use std::fmt; use std::ops::{Deref, DerefMut}; use std::sync::Arc; -use std::thread::{self, JoinHandle}; +use std::thread::JoinHandle; #[derive(thiserror::Error, Debug)] #[error("The hive has been poisoned")] @@ -35,7 +35,7 @@ impl> 
Hive { impl, T: TaskQueues> Hive { /// Spawns a new worker thread with the specified index and with access to the `shared` data. - fn try_spawn( + pub fn try_spawn( thread_index: usize, shared: &Arc>, ) -> Result, SpawnError> { @@ -43,10 +43,13 @@ impl, T: TaskQueues> Hive { let shared = Arc::clone(shared); // spawn a thread that executes the worker loop thread_builder.spawn(move || { - // perform one-time initialization of the worker thread - Self::init_thread(thread_index, &shared); + #[cfg(feature = "affinity")] + if let Some(core) = shared.get_core_affinity(thread_index) { + // try to pin the worker thread to a specific CPU core. + core.try_pin_current(); + } // create a Sentinel that will spawn a new thread on panic until it is cancelled - let sentinel = Sentinel::new(thread_index, Arc::clone(&shared)); + let sentinel = Sentinel::new(thread_index, Arc::clone(&shared), Self::try_spawn); // get the thread-local interface to the task queues let worker_queues = shared.worker_queues(thread_index); // create a new worker to process tasks @@ -54,8 +57,34 @@ impl, T: TaskQueues> Hive { // execute the main loop: get the next task to process, which decrements the queued // counter and increments the active counter while let Some(task) = shared.get_next_task(&worker_queues) { - // execute the task and dispose of the outcome - Self::execute(task, &mut worker, &worker_queues, &shared); + let (input, task_meta, outcome_tx) = task.into_parts(); + let local_ctx = HiveLocalContext::new(&worker_queues, &shared, outcome_tx.as_ref()); + let apply_ctx = Context::new(task_meta, Some(&local_ctx)); + // execute the task until it succeeds or we reach maximum retries - this should + // be the only place where a panic can occur + let result = worker.apply(input, &apply_ctx); + let (task_meta, subtask_ids) = apply_ctx.into_parts(); + let outcome = match result { + #[cfg(feature = "retry")] + Err(ApplyError::Retryable { input, error }) + if subtask_ids.is_none() && shared.can_retry(&task_meta) => + { + match shared.try_send_retry( + input, + task_meta, + outcome_tx.as_ref(), + &worker_queues, + ) { + Ok(_) => return, + Err(task) => { + let (input, task_meta, _) = task.into_parts(); + Outcome::from_fatal(input, task_meta, error) + } + } + } + result => Outcome::from_worker_result(result, task_meta, subtask_ids), + }; + shared.send_or_store_outcome(outcome, outcome_tx); // finish the task - decrements the active counter and notifies other threads shared.finish_task(false); } @@ -99,7 +128,7 @@ impl, T: TaskQueues> Hive { /// Sends one `input` to the `Hive` for procesing and returns the result, blocking until the /// result is available. Creates a channel to send the input and receive the outcome. Returns /// an [`Outcome`] with the task output or an error. - pub fn apply(&self, input: W::Input) -> Outcome { + pub fn apply>>(&self, input: I) -> Outcome { let (tx, rx) = super::outcome_channel(); let task_id = self.shared().send_one_global(input, Some(&tx)); drop(tx); @@ -108,7 +137,7 @@ impl, T: TaskQueues> Hive { /// Sends one `input` to the `Hive` for processing and returns its ID. The [`Outcome`] of /// the task will be sent to `tx` upon completion. - pub fn apply_send(&self, input: W::Input, outcome_tx: X) -> TaskId + pub fn apply_send>, X>(&self, input: I, outcome_tx: X) -> TaskId where X: Borrow>, { @@ -118,7 +147,7 @@ impl, T: TaskQueues> Hive { /// Sends one `input` to the `Hive` for processing and returns its ID immediately. The /// [`Outcome`] of the task will be retained and available for later retrieval. 
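    /// A short usage sketch (illustration only; it assumes a hive of
    /// `ThunkWorker<u8>` built via `channel_builder`, as in the crate's tests):
    ///
    /// ```ignore
    /// let hive = channel_builder(true)
    ///     .with_worker_default::<ThunkWorker<u8>>()
    ///     .build();
    /// // the `Outcome` is retained in the hive, keyed by the returned task ID
    /// let task_id = hive.apply_store(Thunk::of(|| 42u8));
    /// hive.join();
    /// ```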
- pub fn apply_store(&self, input: W::Input) -> TaskId { + pub fn apply_store>>(&self, input: I) -> TaskId { self.shared().send_one_global(input, None) } @@ -127,9 +156,10 @@ impl, T: TaskQueues> Hive { /// /// This method is more efficient than [`map`](Self::map) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm(&self, batch: B) -> impl Iterator> + use + pub fn swarm(&self, batch: B) -> impl Iterator> + use where - B: IntoIterator, + I: Into>, + B: IntoIterator, B::IntoIter: ExactSizeIterator, { let (tx, rx) = super::outcome_channel(); @@ -145,9 +175,13 @@ impl, T: TaskQueues> Hive { /// instead receive the `Outcome`s in the order they were submitted. This method is more /// efficient than [`map_unordered`](Self::map_unordered) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm_unordered(&self, batch: B) -> impl Iterator> + use + pub fn swarm_unordered( + &self, + batch: B, + ) -> impl Iterator> + use where - B: IntoIterator, + I: Into>, + B: IntoIterator, B::IntoIter: ExactSizeIterator, { let (tx, rx) = super::outcome_channel(); @@ -160,11 +194,12 @@ impl, T: TaskQueues> Hive { /// /// This method is more efficient than [`map_send`](Self::map_send) when the input is an /// [`ExactSizeIterator`]. - pub fn swarm_send(&self, batch: B, outcome_tx: S) -> Vec + pub fn swarm_send(&self, batch: B, outcome_tx: S) -> Vec where - S: Borrow>, - B: IntoIterator, + I: Into>, + B: IntoIterator, B::IntoIter: ExactSizeIterator, + S: Borrow>, { self.shared() .send_batch_global(batch, Some(outcome_tx.borrow())) @@ -174,9 +209,10 @@ impl, T: TaskQueues> Hive { /// The [`Outcome`]s of the task are retained and available for later retrieval. /// /// This method is more efficient than `map_store` when the input is an [`ExactSizeIterator`]. - pub fn swarm_store(&self, batch: B) -> Vec + pub fn swarm_store(&self, batch: B) -> Vec where - B: IntoIterator, + I: Into>, + B: IntoIterator, B::IntoIter: ExactSizeIterator, { self.shared().send_batch_global(batch, None) @@ -186,9 +222,10 @@ impl, T: TaskQueues> Hive { /// iterator over the [`Outcome`]s in the same order as the inputs. /// /// [`swarm`](Self::swarm) should be preferred when `inputs` is an [`ExactSizeIterator`]. - pub fn map(&self, batch: B) -> impl Iterator> + use + pub fn map(&self, batch: B) -> impl Iterator> + use where - B: IntoIterator, + I: Into>, + B: IntoIterator, { let (tx, rx) = super::outcome_channel(); let task_ids: Vec<_> = batch @@ -204,9 +241,13 @@ impl, T: TaskQueues> Hive { /// /// [`swarm_unordered`](Self::swarm_unordered) should be preferred when `inputs` is an /// [`ExactSizeIterator`]. - pub fn map_unordered(&self, batch: B) -> impl Iterator> + use + pub fn map_unordered( + &self, + batch: B, + ) -> impl Iterator> + use where - B: IntoIterator, + I: Into>, + B: IntoIterator, { let (tx, rx) = super::outcome_channel(); // `map` is required (rather than `inspect`) because we need owned items @@ -223,9 +264,10 @@ impl, T: TaskQueues> Hive { /// /// [`swarm_send`](Self::swarm_send) should be preferred when `inputs` is an /// [`ExactSizeIterator`]. - pub fn map_send(&self, batch: B, outcome_tx: X) -> Vec + pub fn map_send(&self, batch: B, outcome_tx: X) -> Vec where - B: IntoIterator, + I: Into>, + B: IntoIterator, X: Borrow>, { batch @@ -239,9 +281,10 @@ impl, T: TaskQueues> Hive { /// /// [`swarm_store`](Self::swarm_store) should be preferred when `inputs` is an /// [`ExactSizeIterator`]. 
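    /// For example (illustration only; assumes a hive whose input type is `u32`):
    ///
    /// ```ignore
    /// // fold each item into a running sum and submit the sum as a task input
    /// let (results, total) = hive.try_scan_store(0..5u32, 0u32, |acc, i| {
    ///     *acc += i;
    ///     Ok::<_, String>(*acc)
    /// });
    /// assert_eq!(total, 10);
    /// assert_eq!(results.len(), 5);
    /// ```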
- pub fn map_store(&self, batch: B) -> Vec + pub fn map_store(&self, batch: B) -> Vec where - B: IntoIterator, + I: Into>, + B: IntoIterator, { batch .into_iter() @@ -252,10 +295,11 @@ impl, T: TaskQueues> Hive { /// Iterates over `items` and calls `f` with a mutable reference to a state value (initialized /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing. /// Returns an [`OutcomeBatch`] of the outputs and the final state value. - pub fn scan(&self, batch: B, init: S, f: F) -> (OutcomeBatch, S) + pub fn scan(&self, batch: B, init: S, f: F) -> (OutcomeBatch, S) where B: IntoIterator, - F: FnMut(&mut S, I) -> W::Input, + O: Into>, + F: FnMut(&mut S, I) -> O, { let (tx, rx) = super::outcome_channel(); let (task_ids, fold_value) = self.scan_send(batch, &tx, init, f); @@ -268,15 +312,16 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `F` returns an input that is sent to the `Hive` for processing, /// or an error. Returns an [`OutcomeBatch`] of the outputs, a [`Vec`] of errors, and the final /// state value. - pub fn try_scan( + pub fn try_scan( &self, batch: B, init: S, mut f: F, ) -> (OutcomeBatch, Vec, S) where + O: Into>, B: IntoIterator, - F: FnMut(&mut S, I) -> Result, + F: FnMut(&mut S, I) -> Result, { let (tx, rx) = super::outcome_channel(); let (task_ids, errors, fold_value) = batch.into_iter().fold( @@ -298,7 +343,7 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// The outputs are sent to `tx` in the order they become available. Returns a [`Vec`] of the /// task IDs and the final state value. - pub fn scan_send( + pub fn scan_send( &self, batch: B, outcome_tx: X, @@ -306,9 +351,10 @@ impl, T: TaskQueues> Hive { mut f: F, ) -> (Vec, S) where + O: Into>, B: IntoIterator, X: Borrow>, - F: FnMut(&mut S, I) -> W::Input, + F: FnMut(&mut S, I) -> O, { batch .into_iter() @@ -324,7 +370,7 @@ impl, T: TaskQueues> Hive { /// or an error. The outputs are sent to `tx` in the order they become available. This /// function returns the final state value and a [`Vec`] of results, where each result is /// either a task ID or an error. - pub fn try_scan_send( + pub fn try_scan_send( &self, batch: B, outcome_tx: X, @@ -332,9 +378,10 @@ impl, T: TaskQueues> Hive { mut f: F, ) -> (Vec>, S) where + O: Into>, B: IntoIterator, X: Borrow>, - F: FnMut(&mut S, I) -> Result, + F: FnMut(&mut S, I) -> Result, { batch .into_iter() @@ -350,10 +397,11 @@ impl, T: TaskQueues> Hive { /// to `init`) and each item. `f` returns an input that is sent to the `Hive` for processing. /// This function returns the final state value and a [`Vec`] of task IDs. The [`Outcome`]s of /// the tasks are retained and available for later retrieval. - pub fn scan_store(&self, batch: B, init: S, mut f: F) -> (Vec, S) + pub fn scan_store(&self, batch: B, init: S, mut f: F) -> (Vec, S) where + O: Into>, B: IntoIterator, - F: FnMut(&mut S, I) -> W::Input, + F: FnMut(&mut S, I) -> O, { batch .into_iter() @@ -369,15 +417,16 @@ impl, T: TaskQueues> Hive { /// or an error. This function returns the final value of the state value and a [`Vec`] of /// results, where each result is either a task ID or an error. The [`Outcome`]s of the /// tasks are retained and available for later retrieval. 
- pub fn try_scan_store( + pub fn try_scan_store( &self, batch: B, init: S, mut f: F, ) -> (Vec>, S) where + O: Into>, B: IntoIterator, - F: FnMut(&mut S, I) -> Result, + F: FnMut(&mut S, I) -> Result, { batch .into_iter() @@ -708,79 +757,11 @@ where } } -/// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker -/// thread if it panics. -struct Sentinel -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - thread_index: usize, - shared: Arc>, - active: bool, -} - -impl Sentinel -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - fn new(thread_index: usize, shared: Arc>) -> Self { - Self { - thread_index, - shared, - active: true, - } - } - - /// Cancel and destroy this sentinel. - fn cancel(mut self) { - self.active = false; - } -} - -impl Drop for Sentinel -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - fn drop(&mut self) { - if self.active { - // if the sentinel is active, that means the thread panicked during task execution, so - // we have to finish the task here before respawning - self.shared.finish_task(thread::panicking()); - // only respawn if the sentinel is active and the hive has not been poisoned - if !self.shared.is_poisoned() { - // can't do anything with the previous result - let _ = self - .shared - .respawn_thread(self.thread_index, |thread_index| { - Hive::try_spawn(thread_index, &self.shared) - }); - } - } - } -} - -#[cfg(not(feature = "affinity"))] -mod no_affinity { - use crate::bee::{Queen, Worker}; - use crate::hive::{Hive, Shared, TaskQueues}; - - impl, T: TaskQueues> Hive { - #[inline] - pub(super) fn init_thread(_: usize, _: &Shared) {} - } -} - #[cfg(feature = "affinity")] mod affinity { use crate::bee::{Queen, Worker}; use crate::hive::cores::Cores; - use crate::hive::{Hive, Poisoned, Shared, TaskQueues}; + use crate::hive::{Hive, Poisoned, TaskQueues}; impl Hive where @@ -788,14 +769,6 @@ mod affinity { Q: Queen, T: TaskQueues, { - /// Tries to pin the worker thread to a specific CPU core. - #[inline] - pub(super) fn init_thread(thread_index: usize, shared: &Shared) { - if let Some(core) = shared.get_core_affinity(thread_index) { - core.try_pin_current(); - } - } - /// Attempts to increase the number of worker threads by `num_threads`. 
/// /// The provided `affinity` specifies additional CPU core indices to which the worker @@ -824,8 +797,8 @@ mod affinity { } } -#[cfg(feature = "batching")] -mod batching { +#[cfg(feature = "local-batch")] +mod local_batch { use crate::bee::{Queen, Worker}; use crate::hive::{Hive, TaskQueues}; @@ -851,46 +824,10 @@ mod batching { } } -#[cfg(not(feature = "retry"))] -mod no_retry { - use super::HiveTaskContext; - use crate::bee::{Context, Queen, Worker}; - use crate::hive::{Hive, Outcome, Shared, Task, TaskQueues}; - use std::sync::Arc; - - impl Hive - where - W: Worker, - Q: Queen, - T: TaskQueues, - { - pub(super) fn execute( - task: Task, - worker: &mut W, - worker_queues: &T::WorkerQueues, - shared: &Arc>, - ) { - let (task_id, input, outcome_tx) = task.into_parts(); - let task_ctx = HiveTaskContext { - worker_queues, - shared, - outcome_tx: outcome_tx.as_ref(), - }; - let ctx = Context::new(task_id, Some(&task_ctx)); - let result = worker.apply(input, &ctx); - let subtask_ids = ctx.into_subtask_ids(); - let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); - shared.send_or_store_outcome(outcome, outcome_tx); - } - } -} - #[cfg(feature = "retry")] mod retry { - use super::HiveTaskContext; - use crate::bee::{ApplyError, Context, Queen, Worker}; - use crate::hive::{Hive, Outcome, Shared, Task, TaskQueues}; - use std::sync::Arc; + use crate::bee::{Queen, Worker}; + use crate::hive::{Hive, TaskQueues}; use std::time::Duration; impl Hive @@ -918,90 +855,6 @@ mod retry { pub fn set_worker_retry_factor(&self, duration: Duration) -> Duration { self.shared().set_worker_retry_factor(duration) } - - pub(super) fn execute( - task: Task, - worker: &mut W, - worker_queues: &T::WorkerQueues, - shared: &Arc>, - ) { - let (task_id, input, attempt, outcome_tx) = task.into_parts(); - let task_ctx = HiveTaskContext { - worker_queues, - shared, - outcome_tx: outcome_tx.as_ref(), - }; - let ctx = Context::new(task_id, attempt, Some(&task_ctx)); - // execute the task until it succeeds or we reach maximum retries - this should - // be the only place where a panic can occur - let result = worker.apply(input, &ctx); - let subtask_ids = ctx.into_subtask_ids(); - #[cfg(feature = "retry")] - let result = match result { - Err(ApplyError::Retryable { input, error }) - if subtask_ids.is_none() && shared.can_retry(attempt) => - { - match shared.try_send_retry( - task_id, - input, - outcome_tx.as_ref(), - attempt + 1, - worker_queues, - ) { - Ok(_) => return, - Err(task) => Result::>::Err( - ApplyError::Fatal { - input: Some(task.into_parts().1), - error, - }, - ), - } - } - result => result, - }; - let outcome = Outcome::from_worker_result(result, task_id, subtask_ids); - shared.send_or_store_outcome(outcome, outcome_tx); - } - } -} - -struct HiveTaskContext<'a, W, Q, T> -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - worker_queues: &'a T::WorkerQueues, - shared: &'a Arc>, - outcome_tx: Option<&'a OutcomeSender>, -} - -impl TaskContext for HiveTaskContext<'_, W, Q, T> -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - fn should_cancel_tasks(&self) -> bool { - self.shared.is_suspended() - } - - fn submit_task(&self, input: W::Input) -> TaskId { - let task = self.shared.prepare_task(input, self.outcome_tx); - let task_id = task.id(); - self.worker_queues.push(task); - task_id - } -} - -impl fmt::Debug for HiveTaskContext<'_, W, Q, T> -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HiveTaskContext").finish() } 
}
diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs
index 856f304..6a0cdc3 100644
--- a/src/hive/inner/builder.rs
+++ b/src/hive/inner/builder.rs
@@ -198,12 +198,12 @@ pub trait Builder: BuilderConfig + Sized {
    /// Sets the worker thread batch size.
    ///
-    /// This may have no effect if the `batching` feature is disabled, or if the `TaskQueues`
+    /// This may have no effect if the `local-batch` feature is disabled, or if the `TaskQueues`
    /// implementation used for this hive does not support local batching.
    ///
-    /// If `batch_limit` is `0`, batching is effectively disabled, but note that the performance
-    /// may be worse than with the `batching` feature disabled.
-    #[cfg(feature = "batching")]
+    /// If `batch_limit` is `0`, local batching is effectively disabled, but note that the
+    /// performance may be worse than with the `local-batch` feature disabled.
+    #[cfg(feature = "local-batch")]
    fn batch_limit(mut self, batch_limit: usize) -> Self {
        if batch_limit == 0 {
            self.config_ref(Token).batch_limit.set(None);
@@ -214,7 +214,7 @@ pub trait Builder: BuilderConfig + Sized {
    }

    /// Sets the worker thread batch size to the global default value.
-    #[cfg(feature = "batching")]
+    #[cfg(feature = "local-batch")]
    fn with_default_batch_limit(mut self) -> Self {
        let _ = self
            .config_ref(Token)
@@ -223,6 +223,36 @@ pub trait Builder: BuilderConfig + Sized {
        self
    }

+    /// Sets the maximum total weight of the tasks that a worker thread may have active and
+    /// pending at any given time.
+    ///
+    /// If `weight_limit` is `0`, weighting is effectively disabled, but note that the performance
+    /// may be worse than with the `local-batch` feature disabled.
+    ///
+    /// If a task has a weight greater than the limit, it is immediately converted to
+    /// `Outcome::WeightLimitExceeded` and sent or stored.
+    ///
+    /// This limit determines the maximum total "weight" of the active and pending tasks in a
+    /// worker's local queue.
+    #[cfg(feature = "local-batch")]
+    fn weight_limit(mut self, weight_limit: u64) -> Self {
+        if weight_limit == 0 {
+            self.config_ref(Token).weight_limit.set(None);
+        } else {
+            self.config_ref(Token).weight_limit.set(Some(weight_limit));
+        }
+        self
+    }
+
+    /// Sets the worker thread weight limit to the global default value.
+    #[cfg(feature = "local-batch")]
+    fn with_default_weight_limit(mut self) -> Self {
+        let _ = self
+            .config_ref(Token)
+            .weight_limit
+            .set(super::config::DEFAULTS.lock().weight_limit.get());
+        self
+    }
+
    /// Sets the maximum number of times to retry a
    /// [`ApplyError::Retryable`](crate::bee::ApplyError::Retryable) error.
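A hedged configuration sketch combining the new `weight_limit` knob with the existing builder methods; `my_builder` is an assumption standing in for any type implementing `Builder` with a worker attached, as in this crate's tests.

```rust
// Requires `feature = "local-batch"` for `batch_limit`/`weight_limit`.
let hive = my_builder          // hypothetical builder, e.g. with a `Caller` worker
    .num_threads(4)
    .batch_limit(32)           // at most 32 tasks parked in each worker's local queue
    .weight_limit(1_000)       // cap on the total weight of a worker's active + pending tasks
    .build();
```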
A worker /// thread will retry a task until it either returns diff --git a/src/hive/inner/config.rs b/src/hive/inner/config.rs index 7d4f0b8..c943b35 100644 --- a/src/hive/inner/config.rs +++ b/src/hive/inner/config.rs @@ -1,5 +1,7 @@ -#[cfg(feature = "batching")] -pub use self::batching::set_batch_limit_default; +#[cfg(feature = "local-batch")] +pub use self::local_batch::set_batch_limit_default; +#[cfg(feature = "local-batch")] +pub use self::local_batch::set_weight_limit_default; #[cfg(feature = "retry")] pub use self::retry::{ set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, @@ -43,8 +45,10 @@ impl Config { thread_stack_size: Default::default(), #[cfg(feature = "affinity")] affinity: Default::default(), - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] batch_limit: Default::default(), + #[cfg(feature = "local-batch")] + weight_limit: Default::default(), #[cfg(feature = "retry")] max_retries: Default::default(), #[cfg(feature = "retry")] @@ -55,7 +59,7 @@ impl Config { /// Resets config values to their pre-configured defaults. fn set_const_defaults(&mut self) { self.num_threads.set(Some(DEFAULT_NUM_THREADS)); - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] self.set_batch_const_defaults(); #[cfg(feature = "retry")] self.set_retry_const_defaults(); @@ -69,8 +73,10 @@ impl Config { thread_stack_size: self.thread_stack_size.into_sync(), #[cfg(feature = "affinity")] affinity: self.affinity.into_sync(), - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] batch_limit: self.batch_limit.into_sync_default(), + #[cfg(feature = "local-batch")] + weight_limit: self.weight_limit.into_sync_default(), #[cfg(feature = "retry")] max_retries: self.max_retries.into_sync_default(), #[cfg(feature = "retry")] @@ -87,8 +93,10 @@ impl Config { thread_stack_size: self.thread_stack_size.into_unsync(), #[cfg(feature = "affinity")] affinity: self.affinity.into_unsync(), - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] batch_limit: self.batch_limit.into_unsync(), + #[cfg(feature = "local-batch")] + weight_limit: self.weight_limit.into_unsync(), #[cfg(feature = "retry")] max_retries: self.max_retries.into_unsync(), #[cfg(feature = "retry")] @@ -143,8 +151,8 @@ mod tests { } } -#[cfg(feature = "batching")] -mod batching { +#[cfg(feature = "local-batch")] +mod local_batch { use super::{Config, DEFAULTS}; const DEFAULT_BATCH_LIMIT: usize = 10; @@ -152,10 +160,14 @@ mod batching { pub fn set_batch_limit_default(batch_limit: usize) { DEFAULTS.lock().batch_limit.set(Some(batch_limit)); } + pub fn set_weight_limit_default(weight_limit: u64) { + DEFAULTS.lock().weight_limit.set(Some(weight_limit)); + } impl Config { pub(super) fn set_batch_const_defaults(&mut self) { self.batch_limit.set(Some(DEFAULT_BATCH_LIMIT)); + self.weight_limit.set(None); } } } diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs index eb2c41a..bcd5fa1 100644 --- a/src/hive/inner/mod.rs +++ b/src/hive/inner/mod.rs @@ -7,9 +7,9 @@ mod shared; mod task; pub mod set_config { - #[cfg(feature = "batching")] - pub use super::config::set_batch_limit_default; pub use super::config::{reset_defaults, set_num_threads_default, set_num_threads_default_all}; + #[cfg(feature = "local-batch")] + pub use super::config::{set_batch_limit_default, set_weight_limit_default}; #[cfg(feature = "retry")] pub use super::config::{ set_max_retries_default, set_retries_default_disabled, set_retry_factor_default, @@ -21,12 +21,13 @@ pub mod set_config { pub use 
self::builder::{Builder, BuilderConfig}; pub use self::queue::{ChannelTaskQueues, TaskQueues, WorkerQueues, WorkstealingTaskQueues}; +pub use self::task::TaskInput; use self::counter::DualCounter; use self::gate::{Gate, PhasedGate}; use self::queue::PopTaskError; use crate::atomic::{AtomicAny, AtomicBool, AtomicOption, AtomicUsize}; -use crate::bee::{Queen, TaskId, Worker}; +use crate::bee::{Queen, TaskMeta, Worker}; use crate::hive::{OutcomeQueue, OutcomeSender, SpawnError}; use parking_lot::Mutex; use std::thread::JoinHandle; @@ -44,11 +45,9 @@ pub struct Token; /// Internal representation of a task to be processed by a `Hive`. #[derive(Debug)] pub struct Task { - id: TaskId, input: W::Input, + meta: TaskMeta, outcome_tx: Option>, - #[cfg(feature = "retry")] - attempt: u32, } /// Data shared by all worker threads in a `Hive`. @@ -99,8 +98,11 @@ pub struct Config { #[cfg(feature = "affinity")] affinity: Any, /// Maximum number of tasks for a worker thread to take when receiving from the input channel - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] batch_limit: Usize, + /// Maximum "weight" of tasks a worker thread may have active and pending + #[cfg(feature = "local-batch")] + weight_limit: U64, /// Maximum number of retries for a task #[cfg(feature = "retry")] max_retries: U32, diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index 9e9c9c9..99c5be5 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -1,6 +1,6 @@ //! Implementation of `TaskQueues` that uses `crossbeam` channels for the global queue (i.e., for //! sending tasks from the `Hive` to the worker threads) and a default implementation of local -//! queues that depends on which combination of the `retry` and `batching` features are enabled. +//! queues that depends on which combination of the `retry` and `local-batch` features are enabled. use super::{Config, PopTaskError, Status, Task, TaskQueues, Token, WorkerQueues}; use crate::bee::Worker; use crossbeam_channel::RecvTimeoutError; @@ -127,7 +127,7 @@ impl GlobalQueue { tasks.extend(self.global_rx.try_iter()); } - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] fn try_iter(&self) -> impl Iterator> { self.global_rx.try_iter() } @@ -166,9 +166,9 @@ struct LocalQueueShared { _thread_index: usize, /// queue of abandon tasks local_abandoned: SegQueue>, - /// thread-local queue of tasks used when the `batching` feature is enabled - #[cfg(feature = "batching")] - local_batch: batching::WorkerBatchQueue, + /// thread-local queue of tasks used when the `local-batch` feature is enabled + #[cfg(feature = "local-batch")] + local_batch: local_batch::WorkerBatchQueue, /// thread-local queues used for tasks that are waiting to be retried after a failure #[cfg(feature = "retry")] local_retry: super::RetryQueue, @@ -179,20 +179,27 @@ impl LocalQueueShared { Self { _thread_index: thread_index, local_abandoned: Default::default(), - #[cfg(feature = "batching")] - local_batch: batching::WorkerBatchQueue::new(_config.batch_limit.get_or_default()), + #[cfg(feature = "local-batch")] + local_batch: local_batch::WorkerBatchQueue::new( + _config.batch_limit.get_or_default(), + _config.weight_limit.get_or_default(), + ), #[cfg(feature = "retry")] local_retry: super::RetryQueue::new(_config.retry_factor.get_or_default()), } } /// Updates the local queues based on the provided `config`: - /// If `batching` is enabled, resizes the batch queue if necessary. 
+ /// If `local-batch` is enabled, resizes the batch queue if necessary. /// If `retry` is enabled, updates the retry factor. fn update(&self, _global: &GlobalQueue, _config: &Config) { - #[cfg(feature = "batching")] - self.local_batch - .set_limit(_config.batch_limit.get_or_default(), _global, self); + #[cfg(feature = "local-batch")] + self.local_batch.set_limits( + _config.batch_limit.get_or_default(), + _config.weight_limit.get_or_default(), + _global, + self, + ); #[cfg(feature = "retry")] self.local_retry .set_delay_factor(_config.retry_factor.get_or_default()); @@ -200,7 +207,7 @@ impl LocalQueueShared { #[inline] fn push(&self, task: Task, global: &GlobalQueue) { - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] let task = match self.local_batch.try_push(task) { Ok(_) => return, Err(task) => task, @@ -231,14 +238,14 @@ impl LocalQueueShared { if let Some(task) = self.local_retry.try_pop() { return Ok(task); } - // if batching is enabled, try to get a task from the batch queue - // and try to refill it from the global queue if it's empty - #[cfg(feature = "batching")] + // if local batching is enabled, try to get a task from the batch queue and try to refill + // it from the global queue if it's empty + #[cfg(feature = "local-batch")] { self.local_batch.try_pop_or_refill(global, self) } // fall back to requesting a task from the global queue - #[cfg(not(feature = "batching"))] + #[cfg(not(feature = "local-batch"))] { global.try_pop() } @@ -255,17 +262,17 @@ impl LocalQueueShared { while let Some(task) = self.local_abandoned.pop() { tasks.push(task); } - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] self.local_batch.drain_into(tasks); #[cfg(feature = "retry")] self.local_retry.drain_into(tasks); } } -#[cfg(feature = "batching")] -mod batching { +#[cfg(feature = "local-batch")] +mod local_batch { use super::{GlobalQueue, LocalQueueShared, Task}; - use crate::atomic::{Atomic, AtomicUsize}; + use crate::atomic::{Atomic, AtomicU64, AtomicUsize}; use crate::bee::Worker; use crate::hive::inner::queue::PopTaskError; use crossbeam_queue::ArrayQueue; @@ -273,40 +280,42 @@ mod batching { pub struct WorkerBatchQueue { inner: RwLock>>>, - limit: AtomicUsize, + batch_limit: AtomicUsize, + weight_limit: AtomicU64, } impl WorkerBatchQueue { - pub fn new(batch_limit: usize) -> Self { - if batch_limit == 0 { - Self { - inner: RwLock::new(None), - limit: Default::default(), - } + pub fn new(batch_limit: usize, weight_limit: u64) -> Self { + let inner = if batch_limit > 0 { + Some(ArrayQueue::new(batch_limit)) } else { - Self { - inner: RwLock::new(Some(ArrayQueue::new(batch_limit))), - limit: AtomicUsize::new(batch_limit), - } + None + }; + Self { + inner: RwLock::new(inner), + batch_limit: AtomicUsize::new(batch_limit), + weight_limit: AtomicU64::new(weight_limit), } } - pub fn set_limit( + pub fn set_limits( &self, - limit: usize, + batch_limit: usize, + weight_limit: u64, global: &GlobalQueue, parent: &LocalQueueShared, ) { + self.weight_limit.set(weight_limit); // acquire the exclusive lock first to prevent simultaneous updates let mut queue = self.inner.write(); - let old_limit = self.limit.set(limit); - if old_limit == limit { + let old_limit = self.batch_limit.set(batch_limit); + if old_limit == batch_limit { return; } - let old_queue = if limit == 0 { + let old_queue = if batch_limit == 0 { queue.take() } else { - queue.replace(ArrayQueue::new(limit)) + queue.replace(ArrayQueue::new(batch_limit)) }; if let Some(old_queue) = old_queue { // try to push tasks from 
the old queue to the new one and fall back to pushing @@ -345,18 +354,30 @@ mod batching { } } // otherwise pull at least 1 and up to `batch_limit + 1` tasks from the input channel - // wait for the next task from the receiver let first = global.try_pop()?; // if we succeed in getting the first task, try to refill the local queue - let limit = self.limit.get(); - // batch size 0 means batching is disabled - if limit > 0 { - // otherwise try to take up to `batch_limit` tasks from the input channel - // and add them to the local queue, but don't block if the input channel - // is empty - for task in global.try_iter().take(limit) { - if let Err(task) = local.push(task) { - parent.local_abandoned.push(task); + let batch_limit = self.batch_limit.get(); + // batch size 0 means local batching is disabled + if batch_limit > 0 { + let mut iter = global.try_iter(); + let mut batch_size = 0; + let mut total_weight = first.meta.weight() as u64; + let weight_limit = self.weight_limit.get(); + // try to take up to `batch_limit` tasks from the input channel and add them + // to the local queue, but don't block if the input channel is empty; stop + // early if the weight of the queued tasks exceeds the limit + while batch_size < batch_limit + && (weight_limit == 0 || total_weight < weight_limit) + { + if let Some(task) = iter.next() { + let task_weight = task.meta.weight() as u64; + if let Err(task) = local.push(task) { + parent.local_abandoned.push(task); + break; + } + batch_size += 1; + total_weight += task_weight; + } else { break; } } diff --git a/src/hive/inner/queue/retry.rs b/src/hive/inner/queue/retry.rs index 4a55056..e5dfb5e 100644 --- a/src/hive/inner/queue/retry.rs +++ b/src/hive/inner/queue/retry.rs @@ -47,7 +47,7 @@ impl RetryQueue { Some(queue) => { // compute the delay let delay = 2u64 - .checked_pow(task.attempt - 1) + .checked_pow(task.meta.attempt() - 1) .and_then(|multiplier| { self.delay_factor .get() @@ -143,6 +143,7 @@ impl Eq for DelayedTask {} mod tests { use super::{RetryQueue, Task, Worker}; use crate::bee::stock::EchoWorker; + use crate::bee::{TaskId, TaskMeta}; use std::{thread, time::Duration}; type TestWorker = EchoWorker; @@ -154,13 +155,24 @@ mod tests { } } + impl Task { + /// Creates a new `Task` with the given `task_id`. + fn with_attempt(task_id: TaskId, input: W::Input, attempt: u32) -> Self { + Self { + input, + meta: TaskMeta::with_attempt(task_id, attempt), + outcome_tx: None, + } + } + } + #[test] fn test_works() { let queue = RetryQueue::::new(DELAY); - let task1 = Task::with_attempt(1, 1, None, 1); - let task2 = Task::with_attempt(2, 2, None, 2); - let task3 = Task::with_attempt(3, 3, None, 3); + let task1 = Task::with_attempt(1, 1, 1); + let task2 = Task::with_attempt(2, 2, 2); + let task3 = Task::with_attempt(3, 3, 3); queue.try_push(task1.clone()).unwrap(); queue.try_push(task2.clone()).unwrap(); @@ -188,9 +200,9 @@ mod tests { fn test_into_vec() { let queue = RetryQueue::::new(DELAY); - let task1 = Task::with_attempt(1, 1, None, 1); - let task2 = Task::with_attempt(2, 2, None, 2); - let task3 = Task::with_attempt(3, 3, None, 3); + let task1 = Task::with_attempt(1, 1, 1); + let task2 = Task::with_attempt(2, 2, 2); + let task3 = Task::with_attempt(3, 3, 3); queue.try_push(task1.clone()).unwrap(); queue.try_push(task2.clone()).unwrap(); diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 8810169..09dfe97 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -2,10 +2,10 @@ //! 
Tasks are sent from the `Hive` via a global `Injector` queue. Each worker thread has a local //! `Worker` queue where tasks can be pushed. If the local queue is empty, the worker thread first //! tries to steal a task from the global queue and falls back to stealing from another worker -//! thread. If the `batching` feature is enabled, a worker thread will try to fill its local queue +//! thread. If the `local-batch` feature is enabled, a worker thread will try to fill its local queue //! up to the limit when stealing from the global queue. use super::{Config, PopTaskError, Status, Task, TaskQueues, Token, WorkerQueues}; -#[cfg(feature = "batching")] +#[cfg(feature = "local-batch")] use crate::atomic::Atomic; use crate::bee::Worker; use crossbeam_deque::{Injector, Stealer}; @@ -139,21 +139,46 @@ impl GlobalQueue { /// Tries to steal up to `limit + 1` tasks from the global queue. If at least one task was /// stolen, it is popped and returned. Otherwise tries to steal a task from another worker /// thread. - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] fn try_refill_and_pop( &self, local_batch: &crossbeam_deque::Worker>, - limit: usize, + batch_limit: usize, + weight_limit: u64, ) -> Result, PopTaskError> { - if let Some(task) = self - .queue - .steal_batch_with_limit_and_pop(local_batch, limit + 1) - .success() - { - Ok(task) - } else { - self.try_steal_from_worker() + // if we only have a size limit but not a weight limit, use the batch-stealing function + // provided by `Injector` + if batch_limit > 0 && weight_limit == 0 { + if let Some(first) = self + .queue + .steal_batch_with_limit_and_pop(local_batch, batch_limit + 1) + .success() + { + return Ok(first); + } + } + // try to steal at least one from the global queue + if let Some(first) = self.queue.steal().success() { + if batch_limit > 0 && weight_limit > 0 { + // if batching is enabled and we have a weight limit, try to steal a batch of tasks + // from the global queue one at a time + let mut batch_size = 0; + let mut total_weight = first.meta.weight() as u64; + while let Some(task) = self.queue.steal().success() { + total_weight += task.meta.weight() as u64; + local_batch.push(task); + if total_weight >= weight_limit { + break; + } + batch_size += 1; + if batch_size >= batch_limit { + break; + } + } + } + return Ok(first); } + self.try_steal_from_worker() } fn is_closed(&self) -> bool { @@ -226,8 +251,12 @@ struct LocalQueueShared { _thread_index: usize, /// queue of abandon tasks local_abandoned: SegQueue>, - #[cfg(feature = "batching")] + /// limit on the number of tasks that can be queued + #[cfg(feature = "local-batch")] batch_limit: crate::atomic::AtomicUsize, + /// limit on the total weight of active + queued tasks + #[cfg(feature = "local-batch")] + weight_limit: crate::atomic::AtomicU64, /// thread-local queues used for tasks that are waiting to be retried after a failure #[cfg(feature = "retry")] local_retry: super::RetryQueue, @@ -238,16 +267,20 @@ impl LocalQueueShared { Self { _thread_index: thread_index, local_abandoned: Default::default(), - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] batch_limit: crate::atomic::AtomicUsize::new(_config.batch_limit.get_or_default()), #[cfg(feature = "retry")] local_retry: super::RetryQueue::new(_config.retry_factor.get_or_default()), + #[cfg(feature = "local-batch")] + weight_limit: crate::atomic::AtomicU64::new(_config.weight_limit.get_or_default()), } } fn update(&self, _config: &Config) { - #[cfg(feature = "batching")] + #[cfg(feature = 
"local-batch")] self.batch_limit.set(_config.batch_limit.get_or_default()); + #[cfg(feature = "local-batch")] + self.weight_limit.set(_config.weight_limit.get_or_default()); #[cfg(feature = "retry")] self.local_retry .set_delay_factor(_config.retry_factor.get_or_default()); @@ -274,13 +307,14 @@ impl LocalQueueShared { if let Some(task) = local_batch.pop() { return Ok(task); } - // fall back to requesting a task from the global queue - if batching is enabled, this will - // also try to refill the local queue - #[cfg(feature = "batching")] + // fall back to requesting a task from the global queue - if local batching is enabled, + // this will also try to refill the local queue + #[cfg(feature = "local-batch")] { - let limit = self.batch_limit.get(); - if limit > 0 { - return global.try_refill_and_pop(local_batch, limit); + let batch_limit = self.batch_limit.get(); + if batch_limit > 0 { + let weight_limit = self.weight_limit.get(); + return global.try_refill_and_pop(local_batch, batch_limit, weight_limit); } } global.try_pop_unchecked() diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index 77cc27d..79d9ef2 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -1,4 +1,4 @@ -use super::{Config, PopTaskError, Shared, Task, TaskQueues, Token, WorkerQueues}; +use super::{Config, PopTaskError, Shared, Task, TaskInput, TaskQueues, Token, WorkerQueues}; use crate::atomic::{Atomic, AtomicInt, AtomicUsize}; use crate::bee::{Queen, TaskId, Worker}; use crate::channel::SenderExt; @@ -158,60 +158,55 @@ impl, T: TaskQueues> Shared { /// Increments the number of queued tasks. Returns a new `Task` with the provided input and /// `outcome_tx` and the next ID. - pub fn prepare_task(&self, input: W::Input, outcome_tx: Option<&OutcomeSender>) -> Task { + pub fn prepare_task(&self, input: I, outcome_tx: Option<&OutcomeSender>) -> Task + where + I: Into>, + { self.num_tasks .increment_left(1) .expect("overflowed queued task counter"); let task_id = self.next_task_id.add(1); - Task::new(task_id, input, outcome_tx.cloned()) - } - - /// Adds `task` to the global queue if possible, otherwise abandons it - converts it to an - /// `Unprocessed` outcome and sends it to the outcome channel or stores it in the hive. - pub fn push_global(&self, task: Task) { - // try to send the task to the hive; if the hive is poisoned or if sending fails, convert - // the task into an `Unprocessed` outcome and try to send it to the outcome channel; if - // that fails, store the outcome in the hive - if let Some(abandoned_task) = if self.is_poisoned() { - Some(task) - } else { - self.task_queues.try_push_global(task).err() - } { - self.abandon_task(abandoned_task); - } + Task::new(task_id, input.into(), outcome_tx.cloned()) } /// Creates a new `Task` for the given input and outcome channel, and adds it to the global /// queue. 
- pub fn send_one_global( - &self, - input: W::Input, - outcome_tx: Option<&OutcomeSender>, - ) -> TaskId { + pub fn send_one_global(&self, input: I, outcome_tx: Option<&OutcomeSender>) -> TaskId + where + I: Into>, + { if self.num_threads() == 0 { dbg!("WARNING: no worker threads are active for hive"); } let task = self.prepare_task(input, outcome_tx); + // when the `local-batch` feature is enabled, immediately abandon any task whose weight is + // greater than the configured limit + #[cfg(feature = "local-batch")] + let task = match self.abandon_if_too_heavy(task) { + Ok(task) => task, + Err(task_id) => return task_id, + }; let task_id = task.id(); self.push_global(task); task_id } /// Creates a new `Task` for each input in the given batch and sends them to the global queue. - pub fn send_batch_global( + pub fn send_batch_global( &self, - inputs: I, + batch: B, outcome_tx: Option<&OutcomeSender>, ) -> Vec where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, + I: Into>, + B: IntoIterator, + B::IntoIter: ExactSizeIterator, { #[cfg(debug_assertions)] if self.num_threads() == 0 { dbg!("WARNING: no worker threads are active for hive"); } - let iter = inputs.into_iter(); + let iter = batch.into_iter(); let (min_size, _) = iter.size_hint(); self.num_tasks .increment_left(min_size as u64) @@ -228,7 +223,7 @@ impl, T: TaskQueues> Shared { ) .map_while(move |pair| match pair { (Some(input), Some(task_id)) => { - Some(Task::new(task_id, input, outcome_tx.cloned())) + Some(Task::new(task_id, input.into(), outcome_tx.cloned())) } (Some(input), None) => Some(self.prepare_task(input, outcome_tx)), (None, Some(_)) => panic!("batch contained fewer than {min_size} items"), @@ -254,6 +249,21 @@ impl, T: TaskQueues> Shared { } } + /// Adds `task` to the global queue if possible, otherwise abandons it - converts it to an + /// `Unprocessed` outcome and sends it to the outcome channel or stores it in the hive. + pub fn push_global(&self, task: Task) { + // try to send the task to the hive; if the hive is poisoned or if sending fails, convert + // the task into an `Unprocessed` outcome and try to send it to the outcome channel; if + // that fails, store the outcome in the hive + if let Some(abandoned_task) = if self.is_poisoned() { + Some(task) + } else { + self.task_queues.try_push_global(task).err() + } { + self.abandon_task(abandoned_task); + } + } + /// Returns the next available `Task`. If there is a task in any local queue, it is returned, /// otherwise a task is requested from the global queue. /// @@ -577,8 +587,8 @@ mod affinity { } } -#[cfg(feature = "batching")] -mod batching { +#[cfg(feature = "local-batch")] +mod local_batch { use super::Shared; use crate::bee::{Queen, Worker}; use crate::hive::TaskQueues; @@ -626,7 +636,7 @@ mod batching { #[cfg(feature = "retry")] mod retry { - use crate::bee::{Queen, TaskId, Worker}; + use crate::bee::{Queen, TaskMeta, Worker}; use crate::hive::inner::{Shared, Task, TaskQueues}; use crate::hive::{OutcomeSender, WorkerQueues}; use std::time::{Duration, Instant}; @@ -687,12 +697,12 @@ mod retry { } /// Returns `true` if the hive is configured to retry tasks and the `attempt` field of the - /// given `ctx` is less than the maximum number of retries. - pub fn can_retry(&self, attempt: u32) -> bool { + /// given `task_meta` is less than the maximum number of retries. 
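For intuition, the delay arithmetic used by `RetryQueue::try_push` (shown in an earlier hunk) schedules attempt `n` after roughly `delay_factor * 2^(n - 1)`; a standalone sketch, assuming the factor is stored in nanoseconds and multiplied as the `checked_pow`/`checked_mul` chain in that hunk suggests:

```rust
use std::time::Duration;

// `attempt` is 1-based once a task has failed at least once.
fn retry_delay(delay_factor: u64, attempt: u32) -> Option<Duration> {
    2u64.checked_pow(attempt - 1)
        .and_then(|multiplier| delay_factor.checked_mul(multiplier))
        .map(Duration::from_nanos)
}

assert_eq!(retry_delay(1_000, 1), Some(Duration::from_nanos(1_000)));
assert_eq!(retry_delay(1_000, 3), Some(Duration::from_nanos(4_000)));
```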
+ pub fn can_retry(&self, task_meta: &TaskMeta) -> bool { self.config .max_retries .get() - .map(|max_retries| attempt < max_retries) + .map(|max_retries| task_meta.attempt() < max_retries) .unwrap_or(false) } @@ -700,21 +710,48 @@ mod retry { /// queue for the specified `thread_index`. pub fn try_send_retry( &self, - task_id: TaskId, input: W::Input, + meta: TaskMeta, outcome_tx: Option<&OutcomeSender>, - attempt: u32, worker_queues: &T::WorkerQueues, ) -> Result> { self.num_tasks .increment_left(1) .expect("overflowed queued task counter"); - let task = Task::with_attempt(task_id, input, outcome_tx.cloned(), attempt); + let task = Task::with_meta_inc_attempt(input, meta, outcome_tx.cloned()); worker_queues.try_push_retry(task) } } } +#[cfg(feature = "local-batch")] +mod weighting { + use crate::bee::{Queen, TaskId, Worker}; + use crate::hive::inner::{Shared, Task, TaskQueues}; + + impl Shared + where + W: Worker, + Q: Queen, + T: TaskQueues, + { + pub fn abandon_if_too_heavy(&self, task: Task) -> Result, TaskId> { + let weight_limit = self.config.weight_limit.get().unwrap_or(0); + if weight_limit > 0 && task.meta().weight() as u64 > weight_limit { + let task_id = task.id(); + let (outcome, outcome_tx) = task.into_overweight(); + self.send_or_store_outcome(outcome, outcome_tx); + // decrement the queued counter since it was incremented but the task was never queued + let _ = self.num_tasks.decrement_left(1); + self.no_work_notify_all(); + Err(task_id) + } else { + Ok(task) + } + } + } +} + #[cfg(test)] mod tests { use crate::bee::DefaultQueen; diff --git a/src/hive/inner/task.rs b/src/hive/inner/task.rs index c4ad1ac..d47e06e 100644 --- a/src/hive/inner/task.rs +++ b/src/hive/inner/task.rs @@ -1,11 +1,37 @@ use super::Task; -use crate::bee::{TaskId, Worker}; +use crate::bee::{TaskId, TaskMeta, Worker}; use crate::hive::{Outcome, OutcomeSender}; +pub use task_impl::TaskInput; + impl Task { - /// Returns the ID of this task. + /// Creates a new `Task` with the given metadata. + pub fn with_meta( + input: W::Input, + meta: TaskMeta, + outcome_tx: Option>, + ) -> Self { + Task { + input, + meta, + outcome_tx, + } + } + + #[inline] pub fn id(&self) -> TaskId { - self.id + self.meta.id() + } + + /// Returns a reference to the task metadata. + #[inline] + pub fn meta(&self) -> &TaskMeta { + &self.meta + } + + /// Consumes this `Task` and returns its input, metadata, and outcome sender. + pub fn into_parts(self) -> (W::Input, TaskMeta, Option>) { + (self.input, self.meta, self.outcome_tx) } /// Consumes this `Task` and returns a `Outcome::Unprocessed` outcome with the input and ID, @@ -13,96 +39,108 @@ impl Task { pub fn into_unprocessed(self) -> (Outcome, Option>) { let outcome = Outcome::Unprocessed { input: self.input, - task_id: self.id, + task_id: self.meta.id(), }; (outcome, self.outcome_tx) } } -#[cfg(not(feature = "retry"))] -mod no_retry { +#[cfg(not(feature = "local-batch"))] +mod task_impl { use super::Task; - use crate::bee::{TaskId, Worker}; + use crate::bee::{TaskId, TaskMeta, Worker}; use crate::hive::OutcomeSender; + pub type TaskInput = ::Input; + impl Task { - /// Creates a new `Task`. - pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { + /// Creates a new `Task` with the given `task_id`. 
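From the caller's side, `abandon_if_too_heavy` surfaces as a distinct outcome; a hedged sketch, assuming a hive configured with `.weight_limit(100)` and the blocking `apply` helper, where `input` stands in for a real worker's input:

```rust
match hive.apply(Weighted::new(input, 500)) {   // 500 exceeds the limit of 100
    Outcome::WeightLimitExceeded { input, weight, task_id } => {
        // never queued: the input comes back so it can be split or resubmitted
        eprintln!("task {task_id} rejected at weight {weight}");
        let _ = input;
    }
    other => drop(other), // Success, Failure, Unprocessed, ...
}
```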
+ pub fn new( + task_id: TaskId, + input: TaskInput, + outcome_tx: Option>, + ) -> Self { Task { - id, input, + meta: TaskMeta::new(task_id), outcome_tx, } } - - pub fn into_parts(self) -> (TaskId, W::Input, Option>) { - (self.id, self.input, self.outcome_tx) - } } +} - impl> Clone for Task { - fn clone(&self) -> Self { - Self { - id: self.id.clone(), - input: self.input.clone(), - outcome_tx: self.outcome_tx.clone(), +#[cfg(feature = "local-batch")] +mod task_impl { + use super::Task; + use crate::bee::{TaskId, TaskMeta, Worker}; + use crate::hive::{Outcome, OutcomeSender, Weighted}; + + pub type TaskInput = Weighted<::Input>; + + impl Task { + /// Creates a new `Task` with the given `task_id`. + pub fn new( + task_id: TaskId, + input: TaskInput, + outcome_tx: Option>, + ) -> Self { + let (input, weight) = input.into_parts(); + Task { + input, + meta: TaskMeta::with_weight(task_id, weight), + outcome_tx, } } + + /// Consumes this `Task` and returns a `Outcome::WeightLimitExceeded` outcome with the input, + /// weight, and ID, and the outcome sender. + pub fn into_overweight(self) -> (Outcome, Option>) { + let outcome = Outcome::WeightLimitExceeded { + input: self.input, + weight: self.meta.weight(), + task_id: self.meta.id(), + }; + (outcome, self.outcome_tx) + } } } #[cfg(feature = "retry")] mod retry { use super::Task; - use crate::bee::{TaskId, Worker}; + use crate::bee::{TaskMeta, Worker}; use crate::hive::OutcomeSender; impl Task { /// Creates a new `Task`. - pub fn new(id: TaskId, input: W::Input, outcome_tx: Option>) -> Self { - Task { - id, - input, - outcome_tx, - attempt: 0, - } - } - - /// Creates a new `Task`. - pub fn with_attempt( - id: TaskId, + pub fn with_meta_inc_attempt( input: W::Input, + mut meta: TaskMeta, outcome_tx: Option>, - attempt: u32, ) -> Self { - Task { - id, + meta.inc_attempt(); + Self { input, + meta, outcome_tx, - attempt, } } - - pub fn into_parts(self) -> (TaskId, W::Input, u32, Option>) { - (self.id, self.input, self.attempt, self.outcome_tx) - } } +} - impl> Clone for Task { - fn clone(&self) -> Self { - Self { - id: self.id, - input: self.input.clone(), - outcome_tx: self.outcome_tx.clone(), - attempt: self.attempt, - } +impl> Clone for Task { + fn clone(&self) -> Self { + Self { + input: self.input.clone(), + meta: self.meta.clone(), + outcome_tx: self.outcome_tx.clone(), } } } impl PartialEq for Task { fn eq(&self, other: &Self) -> bool { - self.id == other.id + self.meta.id() == other.meta.id() } } @@ -116,6 +154,6 @@ impl PartialOrd for Task { impl Ord for Task { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.id.cmp(&other.id) + self.meta.id().cmp(&other.meta.id()) } } diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 02823b7..026fe08 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -130,15 +130,15 @@ //! the `retry` feature is not enabled. In such cases, `Retryable` errors are automatically //! converted to `Fatal` errors by the worker thread. //! -//! ## Batching tasks (requires `feature = "batching"`) +//! ## Batching tasks (requires `feature = "local-batch"`) //! //! The performance of a `Hive` can degrade as the number of worker threads grows and/or the //! average duration of a task shrinks, due to increased contention between worker threads when //! receiving tasks from the shared input channel. To improve performance, workers can take more //! than one task each time they access the input channel, and store the extra tasks in a local -//! queue. This behavior is activated by enabling the `batching` feature. +//! 
queue. This behavior is activated by enabling the `local-batch` feature.
//!
-//! With the `batching` feature enabled, `Builder` gains the
+//! With the `local-batch` feature enabled, `Builder` gains the
//! [`batch_limit`](crate::hive::Builder::batch_limit) method for configuring the size of worker threads'
//! local queues, and `Hive` gains the [`set_worker_batch_limit`](crate::hive::Hive::set_worker_batch_limit)
//! method for changing the batch size of an existing `Hive`.
@@ -154,7 +154,7 @@
//! * `num_threads`
//!   * [`set_num_threads_default`]: sets the default to a specific value
//!   * [`set_num_threads_default_all`]: sets the default to all available CPU cores
-//! * [`batch_limit`](crate::hive::set_BATCH_LIMIT_default) (requires `feature = "batching"`)
+//! * [`batch_limit`](crate::hive::set_batch_limit_default) (requires `feature = "local-batch"`)
//! * [`max_retries`](crate::hive::set_max_retries_default) (requires `feature = "retry"`)
//! * [`retry_factor`](crate::hive::set_retry_factor_default) (requires `feature = "retry"`)
//!
@@ -381,6 +381,7 @@
//! ([`Husk::as_builder`](crate::hive::husk::Husk::as_builder)) or a new `Hive`
//! ([`Husk::into_hive`](crate::hive::husk::Husk::into_hive)).
mod builder;
+mod context;
#[cfg(feature = "affinity")]
pub mod cores;
#[allow(clippy::module_inception)]
@@ -388,6 +389,9 @@
mod hive;
mod husk;
mod inner;
mod outcome;
+mod sentinel;
+mod util;
+mod weighted;

pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder, TaskQueuesBuilder};
pub use self::builder::{
@@ -395,11 +399,16 @@
};
pub use self::hive::{DefaultHive, Hive, Poisoned};
pub use self::husk::Husk;
-pub use self::inner::{Builder, ChannelTaskQueues, WorkstealingTaskQueues, set_config::*};
+pub use self::inner::{
+    Builder, ChannelTaskQueues, TaskInput, WorkstealingTaskQueues, set_config::*,
+};
pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore};
+pub use self::weighted::Weighted;

use self::context::HiveLocalContext;
use self::inner::{Config, Shared, Task, TaskQueues, WorkerQueues};
use self::outcome::{DerefOutcomes, OutcomeQueue, OwnedOutcomes};
+use self::sentinel::Sentinel;
use crate::bee::Worker;
use crate::channel::{Receiver, Sender, channel};
use std::io::Error as SpawnError;
@@ -423,42 +432,6 @@ pub mod prelude {
    };
}

-mod util {
-    use crossbeam_utils::Backoff;
-    use std::sync::Arc;
-    use std::time::{Duration, Instant};
-
-    const MAX_WAIT: Duration = Duration::from_secs(10);
-
-    /// Utility function to loop (with exponential backoff) waiting for other references to `arc` to
-    /// drop so it can be unwrapped into its inner value.
-    ///
-    /// If `arc` cannot be unwrapped with a certain amount of time (with an exponentially
-    /// increasing gap between each iteration), `arc` is returned as an error.
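The default-setting functions listed in the doc section above are re-exported through `hive::set_config` in this hunk; a hedged startup sketch, with signatures assumed from the definitions in `src/hive/inner/config.rs`:

```rust
use beekeeper::hive::{set_batch_limit_default, set_num_threads_default, set_weight_limit_default};

// These affect builders created afterwards; existing hives are untouched.
set_num_threads_default(8);
set_batch_limit_default(32);   // requires `feature = "local-batch"`
set_weight_limit_default(500); // requires `feature = "local-batch"`
```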
- pub fn unwrap_arc(mut arc: Arc) -> Result> { - // wait for worker threads to drop, then take ownership of the shared data and convert it - // into a Husk - let mut backoff = None::; - let mut start = None::; - loop { - arc = match std::sync::Arc::try_unwrap(arc) { - Ok(inner) => { - return Ok(inner); - } - Err(arc) if start.is_none() => { - let _ = start.insert(Instant::now()); - arc - } - Err(arc) if Instant::now() - start.unwrap() > MAX_WAIT => return Err(arc), - Err(arc) => { - backoff.get_or_insert_with(Backoff::new).spin(); - arc - } - }; - } - } -} - #[cfg(test)] mod tests { use super::inner::TaskQueues; @@ -1206,7 +1179,7 @@ mod tests { F: Fn(bool) -> B, { let hive = thunk_hive::(8, builder_factory(false)); - #[cfg(feature = "batching")] + #[cfg(feature = "local-batch")] assert_eq!(hive.worker_batch_limit(), 0); let (tx, rx) = super::outcome_channel(); let mut task_ids = hive.swarm_send( @@ -1266,10 +1239,10 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(Caller::of(|i| i * i)) + .with_worker(Caller::of(|i: usize| i * i)) .num_threads(4) .build(); - let (outputs, state) = hive.scan(0..10, 0, |acc, i| { + let (outputs, state) = hive.scan(0..10usize, 0, |acc, i| { *acc += i; *acc }); @@ -1296,7 +1269,7 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(Caller::of(|i| i * i)) + .with_worker(Caller::of(|i: i32| i * i)) .num_threads(4) .build(); let (tx, rx) = super::outcome_channel(); @@ -1336,7 +1309,7 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(Caller::of(|i| i * i)) + .with_worker(Caller::of(|i: i32| i * i)) .num_threads(4) .build(); let (tx, rx) = super::outcome_channel(); @@ -1384,7 +1357,7 @@ mod tests { .build(); let (tx, _) = super::outcome_channel(); let _ = hive - .try_scan_send(0..10, &tx, 0, |_, _| Err("fail")) + .try_scan_send(0..10, &tx, 0, |_, _| Err::("fail")) .0 .into_iter() .map(Result::unwrap) @@ -1398,7 +1371,7 @@ mod tests { F: Fn(bool) -> B, { let mut hive = builder_factory(false) - .with_worker(Caller::of(|i| i * i)) + .with_worker(Caller::of(|i: i32| i * i)) .num_threads(4) .build(); let (mut task_ids, state) = hive.scan_store(0..10, 0, |acc, i| { @@ -1439,7 +1412,7 @@ mod tests { F: Fn(bool) -> B, { let mut hive = builder_factory(false) - .with_worker(Caller::of(|i| i * i)) + .with_worker(Caller::of(|i: i32| i * i)) .num_threads(4) .build(); let (results, state) = hive.try_scan_store(0..10, 0, |acc, i| { @@ -1486,7 +1459,7 @@ mod tests { .num_threads(4) .build(); let _ = hive - .try_scan_store(0..10, 0, |_, _| Err("fail")) + .try_scan_store(0..10, 0, |_, _| Err::("fail")) .0 .into_iter() .map(Result::unwrap) @@ -2017,8 +1990,8 @@ mod affinity_tests { } } -#[cfg(all(test, feature = "batching"))] -mod batching_tests { +#[cfg(all(test, feature = "local-batch"))] +mod local_batch_tests { use crate::barrier::IndexedBarrier; use crate::bee::DefaultQueen; use crate::bee::stock::{Thunk, ThunkWorker}; @@ -2113,7 +2086,7 @@ mod batching_tests { } #[rstest] - fn test_batching_channel() { + fn test_local_batch_channel() { const NUM_THREADS: usize = 4; const BATCH_LIMIT: usize = 24; let hive = channel_builder(false) @@ -2125,7 +2098,7 @@ mod batching_tests { } #[rstest] - fn test_batching_workstealing() { + fn test_local_batch_workstealing() { const NUM_THREADS: usize = 4; const BATCH_LIMIT: usize = 24; let hive = workstealing_builder(false) @@ -2243,7 +2216,7 @@ mod retry_tests { .retry_factor(Duration::from_secs(1)) .build(); - let v: Result, _> = 
hive.swarm(0..10).into_results().collect(); + let v: Result, _> = hive.swarm(0..10usize).into_results().collect(); assert_eq!(v.unwrap().len(), 10); } @@ -2277,7 +2250,7 @@ mod retry_tests { .max_retries(3) .build(); - let (success, retry_failed, not_retried) = hive.swarm(0..10).fold( + let (success, retry_failed, not_retried) = hive.swarm(0..10usize).fold( (0, 0, 0), |(success, retry_failed, not_retried), outcome| match outcome { Outcome::Success { .. } => (success + 1, retry_failed, not_retried), @@ -2304,7 +2277,7 @@ mod retry_tests { .with_thread_per_core() .with_no_retries() .build(); - let v: Result, _> = hive.swarm(0..10).into_results().collect(); + let v: Result, _> = hive.swarm(0..10usize).into_results().collect(); assert!(v.is_err()); } } diff --git a/src/hive/outcome/mod.rs b/src/hive/outcome/mod.rs index 3669e51..985fc2b 100644 --- a/src/hive/outcome/mod.rs +++ b/src/hive/outcome/mod.rs @@ -75,6 +75,13 @@ pub enum Outcome { task_id: TaskId, subtask_ids: Vec, }, + /// The task's weight was larger than the configured limit for the `Hive`. + #[cfg(feature = "local-batch")] + WeightLimitExceeded { + input: W::Input, + weight: u32, + task_id: TaskId, + }, /// The task failed after retrying the maximum number of times. #[cfg(feature = "retry")] MaxRetriesAttempted { diff --git a/src/hive/outcome/outcome.rs b/src/hive/outcome/outcome.rs index 39b8886..f1aff75 100644 --- a/src/hive/outcome/outcome.rs +++ b/src/hive/outcome/outcome.rs @@ -1,15 +1,28 @@ use super::Outcome; -use crate::bee::{ApplyError, TaskId, Worker, WorkerResult}; +use crate::bee::{ApplyError, TaskId, TaskMeta, Worker, WorkerResult}; use std::cmp::Ordering; use std::fmt::Debug; impl Outcome { + pub(in crate::hive) fn from_fatal( + input: W::Input, + task_meta: TaskMeta, + error: W::Error, + ) -> Self { + Self::Failure { + input: Some(input), + error, + task_id: task_meta.id(), + } + } + /// Converts a worker `result` into an `Outcome` with the given task_id and optional subtask ids. pub(in crate::hive) fn from_worker_result( result: WorkerResult, - task_id: TaskId, + task_meta: TaskMeta, subtask_ids: Option>, ) -> Self { + let task_id = task_meta.id(); match (result, subtask_ids) { (Ok(value), Some(subtask_ids)) => Self::SuccessWithSubtasks { value, @@ -112,6 +125,8 @@ impl Outcome { | Self::Missing { task_id } | Self::Panic { task_id, .. } | Self::PanicWithSubtasks { task_id, .. } => task_id, + #[cfg(feature = "local-batch")] + Self::WeightLimitExceeded { task_id, .. } => task_id, #[cfg(feature = "retry")] Self::MaxRetriesAttempted { task_id, .. } => task_id, } @@ -156,6 +171,8 @@ impl Outcome { Some(input) } Self::Success { .. } | Self::SuccessWithSubtasks { .. } | Self::Missing { .. } => None, + #[cfg(feature = "local-batch")] + Self::WeightLimitExceeded { input, .. } => Some(input), #[cfg(feature = "retry")] Self::MaxRetriesAttempted { input, .. } => Some(input), } @@ -176,6 +193,8 @@ impl Outcome { | Self::Unprocessed { .. } | Self::UnprocessedWithSubtasks { .. } | Self::Missing { .. } => None, + #[cfg(feature = "local-batch")] + Self::WeightLimitExceeded { .. } => None, #[cfg(feature = "retry")] Self::MaxRetriesAttempted { error, .. } => Some(error), } @@ -241,6 +260,11 @@ impl Debug for Outcome { .field("task_id", task_id) .field("subtask_ids", subtask_ids) .finish(), + #[cfg(feature = "local-batch")] + Self::WeightLimitExceeded { task_id, .. } => f + .debug_struct("WeightLimitExceeded") + .field("task_id", task_id) + .finish(), #[cfg(feature = "retry")] Self::MaxRetriesAttempted { error, task_id, .. 
} => f .debug_struct("MaxRetriesAttempted") diff --git a/src/hive/sentinel.rs b/src/hive/sentinel.rs new file mode 100644 index 0000000..795aa03 --- /dev/null +++ b/src/hive/sentinel.rs @@ -0,0 +1,67 @@ +use super::{Shared, TaskQueues}; +use crate::bee::{Queen, Worker}; +use std::io::Error as SpawnError; +use std::sync::Arc; +use std::thread::JoinHandle; + +/// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker +/// thread if it panics. +pub struct Sentinel +where + W: Worker, + Q: Queen, + T: TaskQueues, + F: Fn(usize, &Arc>) -> Result, SpawnError> + 'static, +{ + thread_index: usize, + shared: Arc>, + active: bool, + respawn_fn: F, +} + +impl Sentinel +where + W: Worker, + Q: Queen, + T: TaskQueues, + F: Fn(usize, &Arc>) -> Result, SpawnError> + 'static, +{ + pub fn new(thread_index: usize, shared: Arc>, respawn_fn: F) -> Self { + Self { + thread_index, + shared, + active: true, + respawn_fn, + } + } + + /// Cancel and destroy this sentinel. + pub fn cancel(mut self) { + self.active = false; + } +} + +impl Drop for Sentinel +where + W: Worker, + Q: Queen, + T: TaskQueues, + F: Fn(usize, &Arc>) -> Result, SpawnError> + 'static, +{ + fn drop(&mut self) { + if self.active { + // if the sentinel is active, that means the thread panicked during task execution, so + // we have to finish the task here before respawning + self.shared.finish_task(std::thread::panicking()); + // only respawn if the sentinel is active and the hive has not been poisoned + if !self.shared.is_poisoned() { + // can't do anything with the previous result + let _ = self + .shared + .respawn_thread(self.thread_index, |thread_index| { + (self.respawn_fn)(thread_index, &self.shared) + }); + } + } + } +} diff --git a/src/hive/util.rs b/src/hive/util.rs new file mode 100644 index 0000000..4513c0f --- /dev/null +++ b/src/hive/util.rs @@ -0,0 +1,33 @@ +use crossbeam_utils::Backoff; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +const MAX_WAIT: Duration = Duration::from_secs(10); + +/// Utility function to loop (with exponential backoff) waiting for other references to `arc` to +/// drop so it can be unwrapped into its inner value. +/// +/// If `arc` cannot be unwrapped with a certain amount of time (with an exponentially +/// increasing gap between each iteration), `arc` is returned as an error. +pub fn unwrap_arc(mut arc: Arc) -> Result> { + // wait for worker threads to drop, then take ownership of the shared data and convert it + // into a Husk + let mut backoff = None::; + let mut start = None::; + loop { + arc = match std::sync::Arc::try_unwrap(arc) { + Ok(inner) => { + return Ok(inner); + } + Err(arc) if start.is_none() => { + let _ = start.insert(Instant::now()); + arc + } + Err(arc) if Instant::now() - start.unwrap() > MAX_WAIT => return Err(arc), + Err(arc) => { + backoff.get_or_insert_with(Backoff::new).spin(); + arc + } + }; + } +} diff --git a/src/hive/weighted.rs b/src/hive/weighted.rs new file mode 100644 index 0000000..cb5541d --- /dev/null +++ b/src/hive/weighted.rs @@ -0,0 +1,172 @@ +use std::ops::Deref; + +/// Wraps a value of type `T` and an associated weight. 
+pub struct Weighted<T> {
+    value: T,
+    weight: u32,
+}
+
+impl<T> Weighted<T> {
+    pub fn new(value: T, weight: u32) -> Self {
+        Self { value, weight }
+    }
+
+    pub fn from_fn<F>(value: T, f: F) -> Self
+    where
+        F: FnOnce(&T) -> u32,
+    {
+        let weight = f(&value);
+        Self::new(value, weight)
+    }
+
+    pub fn weight(&self) -> u32 {
+        self.weight
+    }
+
+    pub fn into_inner(self) -> T {
+        self.value
+    }
+
+    pub fn into_parts(self) -> (T, u32) {
+        (self.value, self.weight)
+    }
+}
+
+impl<T: Into<u32> + Clone> Weighted<T> {
+    pub fn from_identity(value: T) -> Self {
+        let weight = value.clone().into();
+        Self::new(value, weight)
+    }
+}
+
+impl<T> Deref for Weighted<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        &self.value
+    }
+}
+
+impl<T> From<T> for Weighted<T> {
+    fn from(value: T) -> Self {
+        Self::new(value, 0)
+    }
+}
+
+impl<T> From<(T, u32)> for Weighted<T> {
+    fn from((value, weight): (T, u32)) -> Self {
+        Self::new(value, weight)
+    }
+}
+
+// /// Trait implemented by a type that can be converted into an iterator over `Weighted` items.
+// pub trait IntoWeightedIterator {
+//     type Item;
+//     type IntoIter: Iterator<Item = Weighted<Self::Item>>;
+
+//     fn into_weighted_iter(self) -> Self::IntoIter;
+// }
+
+// // Calls to Hive task submission functions should work with iterators of `W::Input` regardless of
+// // whether the `local-batch` feature is enabled. This blanket implementation of
+// // `IntoWeightedIterator` just gives every task a weight of 0.
+// impl<T, I: IntoIterator<Item = T>> IntoWeightedIterator for I {
+//     type Item = T;
+//     type IntoIter = std::iter::Map<I::IntoIter, fn(T) -> Weighted<T>>;
+
+//     fn into_weighted_iter(self) -> Self::IntoIter {
+//         self.into_iter().map(Into::into)
+//     }
+// }
+
+// struct WeightedIter<T>(Box<dyn Iterator<Item = Weighted<T>>>);
+
+// impl<T> WeightedIter<T> {
+//     fn from_tuples<I>(tuples: I) -> Self
+//     where
+//         I: IntoIterator<Item = (T, u32)>,
+//         I::IntoIter: 'static,
+//     {
+//         Self(Box::new(tuples.into_iter().map(Into::into)))
+//     }
+
+//     fn from_iter<I, W>(items: I, weights: W) -> Self
+//     where
+//         I: IntoIterator<Item = T>,
+//         I::IntoIter: 'static,
+//         W: IntoIterator<Item = u32>,
+//         W::IntoIter: 'static,
+//     {
+//         Self(Box::new(
+//             items
+//                 .into_iter()
+//                 .zip(weights.into_iter().chain(std::iter::repeat(0)))
+//                 .map(Into::into),
+//         ))
+//     }
+
+//     fn from_const<I>(items: I, weight: u32) -> Self
+//     where
+//         I: IntoIterator<Item = T>,
+//         I::IntoIter: 'static,
+//     {
+//         Self::from_iter(items, std::iter::repeat(weight))
+//     }
+
+//     fn from_fn<B, F>(items: B, f: F) -> Self
+//     where
+//         B: IntoIterator<Item = T>,
+//         B::IntoIter: 'static,
+//         F: Fn(&T) -> u32 + 'static,
+//     {
+//         Self(Box::new(items.into_iter().map(move |value| {
+//             let weight = f(&value);
+//             Weighted::new(value, weight)
+//         })))
+//     }
+// }
+
+// impl<T: Into<u32> + Clone + 'static> WeightedIter<T> {
+//     fn from_identity<I>(items: I) -> Self
+//     where
+//         I: IntoIterator<Item = T>,
+//         I::IntoIter: 'static,
+//     {
+//         Self::from_fn(items, |item| item.clone().into())
+//     }
+// }
+
+// impl<T> IntoWeightedIterator for WeightedIter<T> {
+//     type Item = T;
+//     type IntoIter = Box<dyn Iterator<Item = Weighted<T>>>;
+
+//     fn into_weighted_iter(self) -> Self::IntoIter {
+//         self.0
+//     }
+// }
+
+// #[cfg(feature = "local-batch")]
+// pub fn map2<B>(&self, batch: B) -> impl Iterator<Item = Outcome<W>> + use<'_, W, Q, T, B>
+// where
+//     B: IntoWeightedIterator<Item = W::Input>,
+// {
+//     let (tx, rx) = outcome_channel();
+//     let task_ids: Vec<_> = batch
+//         .into_weighted_iter()
+//         .map(|(task, weight)| self.apply_send(task, &tx))
+//         .collect();
+//     drop(tx);
+//     rx.select_ordered(task_ids)
+// }
+// #[cfg(test)]
+// mod foo {
+//     use super::*;
+//     use crate::bee::stock::EchoWorker;
+
+//     fn test_foo() {
+//         let hive = DefaultHive::<EchoWorker<usize>>::default();
+//         let result = hive
+//             .map2(Weighted::from_iter(0..10, 0..10))
+// .collect::>(); +// } +// } From 0b76afa6c4ee09ffabdba0c76077c8043ba7545c Mon Sep 17 00:00:00 2001 From: jdidion Date: Thu, 27 Feb 2025 13:37:00 -0800 Subject: [PATCH 43/67] move mock --- CHANGELOG.md | 2 +- src/bee/mock.rs | 44 ----------------------------------------- src/bee/mod.rs | 2 -- src/hive/mock.rs | 51 ++++++++++++++++++++++++++++++++++++++++++++++++ src/hive/mod.rs | 2 ++ 5 files changed, 54 insertions(+), 47 deletions(-) delete mode 100644 src/bee/mock.rs create mode 100644 src/hive/mock.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d41e62..f3cc9b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ The general theme of this release is performance improvement by eliminating thre * Switched to using thread-local retry queues for the implementation of the `retry` feature, to reduce thread-contention. * Switched to storing `Outcome`s in the hive using a data structure that does not require locking when inserting, which should reduce thread contention when using `*_store` operations. * Switched to using `crossbeam_channel` for the task input channel in `ChannelTaskQueues`. These are multi-produer, multi-consumer channels (mpmc; as opposed to `std::mpsc`, which is single-consumer), which means it is no longer necessary for worker threads to aquire a Mutex lock on the channel receiver when getting tasks. - * Added the `beekeeper::bee::mock` module, which has a mock implementation of `beekeeper::bee::context::LocalContext`, and a `apply` function for `apply`ing a worker in a mock context. This is useful for testing your `Worker`. + * Added the `beekeeper::hive::mock` module, which has a `MockTaskRunner` for `apply`ing a worker in a mock context. This is useful for testing your `Worker`. ## 0.2.1 diff --git a/src/bee/mock.rs b/src/bee/mock.rs deleted file mode 100644 index c73b91f..0000000 --- a/src/bee/mock.rs +++ /dev/null @@ -1,44 +0,0 @@ -use super::{Context, LocalContext, TaskId, TaskMeta, Worker, WorkerResult}; -use std::cell::RefCell; - -/// Applies the given `worker` to the given `input` using the given `task_meta`. -/// -/// Returns a tuple of the apply result, the (possibly modified) task metadata, and the IDs of any -/// subtasks that were submitted. -pub fn apply( - input: W::Input, - task_meta: TaskMeta, - worker: &mut W, -) -> (WorkerResult, TaskMeta, Option>) { - let local = MockLocalContext::new(task_meta.id()); - let ctx = Context::new(task_meta, Some(&local)); - let result = worker.apply(input, &ctx); - let (task_meta, subtask_ids) = ctx.into_parts(); - (result, task_meta, subtask_ids) -} - -#[derive(Debug, Default)] -pub struct MockLocalContext(RefCell); - -impl MockLocalContext { - pub fn new(task_id: TaskId) -> Self { - Self(RefCell::new(task_id)) - } - - pub fn into_task_count(self) -> usize { - self.0.into_inner() - } -} - -impl LocalContext for MockLocalContext { - fn should_cancel_tasks(&self) -> bool { - false - } - - fn submit_task(&self, _: I) -> super::TaskId { - let mut task_id = self.0.borrow_mut(); - let cur_id = *task_id; - *task_id += 1; - cur_id - } -} diff --git a/src/bee/mod.rs b/src/bee/mod.rs index de3938a..9501a49 100644 --- a/src/bee/mod.rs +++ b/src/bee/mod.rs @@ -112,8 +112,6 @@ //! workers, the queen, and/or the client thread(s). 
mod context; mod error; -#[cfg_attr(coverage_nightly, coverage(off))] -pub mod mock; mod queen; pub mod stock; mod worker; diff --git a/src/hive/mock.rs b/src/hive/mock.rs new file mode 100644 index 0000000..9158506 --- /dev/null +++ b/src/hive/mock.rs @@ -0,0 +1,51 @@ +use super::{Task, TaskInput}; +use crate::bee::{Context, LocalContext, TaskId, TaskMeta, Worker, WorkerResult}; +use std::cell::RefCell; + +#[derive(Debug)] +pub struct MockTaskRunner(RefCell); + +impl MockTaskRunner { + pub fn new() -> Self { + Self(RefCell::new(0)) + } + + /// Applies the given `worker` to the given `input`. + /// + /// Returns a tuple of the apply result, the task metadata, and the IDs of any + /// subtasks that were submitted. + pub fn apply( + &self, + input: TaskInput, + worker: &mut W, + ) -> (WorkerResult, TaskMeta, Option>) { + let task_id = self.next_task_id(); + let local = MockLocalContext(&self); + let task: Task = Task::new(task_id, input, None); + let (input, task_meta, _) = task.into_parts(); + let ctx = Context::new(task_meta, Some(&local)); + let result = worker.apply(input, &ctx); + let (task_meta, subtask_ids) = ctx.into_parts(); + (result, task_meta, subtask_ids) + } + + fn next_task_id(&self) -> TaskId { + let mut task_id_counter = self.0.borrow_mut(); + let task_id = *task_id_counter; + *task_id_counter += 1; + task_id + } +} + +#[derive(Debug)] +struct MockLocalContext<'a>(&'a MockTaskRunner); + +impl<'a, I> LocalContext for MockLocalContext<'a> { + fn should_cancel_tasks(&self) -> bool { + false + } + + fn submit_task(&self, _: I) -> TaskId { + self.0.next_task_id() + } +} diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 026fe08..88bf413 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -388,6 +388,8 @@ pub mod cores; mod hive; mod husk; mod inner; +#[cfg_attr(coverage_nightly, coverage(off))] +pub mod mock; mod outcome; mod sentinel; mod util; From b170dd78f5c93d7af9a6427eb462e65887fdc272 Mon Sep 17 00:00:00 2001 From: jdidion Date: Tue, 4 Mar 2025 10:33:02 -0800 Subject: [PATCH 44/67] add test --- CHANGELOG.md | 1 + Cargo.toml | 5 - src/atomic.rs | 1 - src/barrier.rs | 6 +- src/bee/context.rs | 23 ++- src/bee/mod.rs | 1 - src/bee/queen.rs | 4 + src/hive/builder/bee.rs | 12 +- src/hive/builder/full.rs | 5 + src/hive/builder/mod.rs | 11 +- src/hive/builder/open.rs | 2 + src/hive/builder/queue.rs | 5 + src/hive/context.rs | 2 + src/hive/cores.rs | 8 +- src/hive/hive.rs | 21 +- src/hive/inner/builder.rs | 2 +- src/hive/inner/config.rs | 13 +- src/hive/inner/counter.rs | 4 +- src/hive/inner/mod.rs | 9 +- src/hive/inner/queue/channel.rs | 20 +- src/hive/inner/queue/retry.rs | 4 +- src/hive/inner/queue/status.rs | 16 +- src/hive/inner/queue/workstealing.rs | 2 + src/hive/inner/shared.rs | 252 +++++++++++------------ src/hive/inner/task.rs | 47 ++--- src/hive/mock.rs | 20 +- src/hive/mod.rs | 39 +++- src/hive/outcome/batch.rs | 1 + src/hive/outcome/{outcome.rs => impl.rs} | 26 +-- src/hive/outcome/mod.rs | 3 +- src/hive/outcome/queue.rs | 6 +- src/hive/outcome/store.rs | 3 + src/hive/sentinel.rs | 10 +- src/hive/util.rs | 5 +- src/hive/weighted.rs | 222 +++++++++----------- src/lib.rs | 1 - src/util.rs | 6 +- 37 files changed, 454 insertions(+), 364 deletions(-) rename src/hive/outcome/{outcome.rs => impl.rs} (99%) diff --git a/CHANGELOG.md b/CHANGELOG.md index f3cc9b8..31b7597 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The general theme of this release is performance improvement by eliminating thre * `beekeeper::bee::Queen::create` now takes `&self` rather 
than `&mut self`. There is a new type, `beekeeper::bee::QueenMut`, with a `create(&mut self)` method, and needs to be wrapped in a `beekeeper::bee::QueenCell` to implement the `Queen` trait. This enables the `Hive` to create new workers without locking in the case of a `Queen` that does not need mutable state. * `beekeeper::bee::Context` now takes a generic parameter that must be input type of the `Worker`. * `beekeeper::hive::Hive::try_into_husk` now has an `urgent` parameter to indicate whether queued tasks should be abandoned when shutting down the hive (`true`) or if they should be allowed to finish processing (`false`). + * The type of `attempt` and `max_retries` has been changed to `u8`. This reduces memory usage and should cover the majority of use cases. * Features * Added the `TaskQueues` trait, which enables `Hive` to be specialized for different implementations of global (i.e., sending tasks from the `Hive` to worker threads) and local (i.e., worker thread-specific) queues. * `ChannelTaskQueues` implements the existing behavior, using a channel for sending tasks. diff --git a/Cargo.toml b/Cargo.toml index c3bdb43..4f2c9aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,11 +49,6 @@ crossbeam = [] flume = ["dep:flume"] loole = ["dep:loole"] -[lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = [ - 'cfg(coverage,coverage_nightly)', -] } - [package.metadata.cargo-all-features] allowlist = ["affinity", "local-batch", "retry"] diff --git a/src/atomic.rs b/src/atomic.rs index ef0e72c..4bf8791 100644 --- a/src/atomic.rs +++ b/src/atomic.rs @@ -4,7 +4,6 @@ //! TODO: The `Atomic` and `AtomicNumeric` traits and implementations could be replaced with the //! equivalents from the `atomic`, `atomig`, or `radium` crates, but none of those seem to be //! well-maintained at this point. - pub use num::PrimInt; use paste::paste; use std::fmt::Debug; diff --git a/src/barrier.rs b/src/barrier.rs index 51c2206..9ee84ff 100644 --- a/src/barrier.rs +++ b/src/barrier.rs @@ -1,12 +1,12 @@ use parking_lot::RwLock; use std::collections::HashSet; -/// Enables multiple threads to synchronize the beginning of some computation. Unlike -/// [`std::sync::Barrier`], this one keeps track of which threads have reached it and only -/// recognizes the first wait from each thread. use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Barrier}; use std::thread::{self, ThreadId}; +/// Enables multiple threads to synchronize the beginning of some computation. Unlike +/// [`std::sync::Barrier`], this one keeps track of which threads have reached it and only +/// recognizes the first wait from each thread. #[derive(Clone)] pub struct IndexedBarrier(Arc); diff --git a/src/bee/context.rs b/src/bee/context.rs index 0150ecc..4712564 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -1,5 +1,4 @@ //! The context for a task processed by a `Worker`. - use std::cell::RefCell; use std::fmt::Debug; @@ -34,7 +33,7 @@ impl<'a, I> Context<'a, I> { } } - /// Creates a new `Context` with the given task_id and shared cancellation status. + /// Creates a new `Context` with the given task metadata and shared state. pub fn new(meta: TaskMeta, local: Option<&'a dyn LocalContext>) -> Self { Self { meta, @@ -45,12 +44,12 @@ impl<'a, I> Context<'a, I> { /// The unique ID of this task within the `Hive`. pub fn task_id(&self) -> TaskId { - self.meta.id + self.meta.id() } /// Returns the number of previous failed attempts to execute the current task. 
- pub fn attempt(&self) -> u32 { - self.meta.attempt + pub fn attempt(&self) -> u8 { + self.meta.attempt() } /// Returns `true` if the current task should be cancelled. @@ -60,7 +59,7 @@ impl<'a, I> Context<'a, I> { pub fn is_cancelled(&self) -> bool { self.local .as_ref() - .map(|worker| worker.should_cancel_tasks()) + .map(|local| local.should_cancel_tasks()) .unwrap_or(false) } @@ -102,14 +101,16 @@ pub struct TaskMeta { #[cfg(feature = "local-batch")] weight: u32, #[cfg(feature = "retry")] - attempt: u32, + attempt: u8, } impl TaskMeta { + /// Creates an empty `TaskMeta` with a default task ID. pub fn empty() -> Self { Self::new(0) } + /// Creates a new `TaskMeta` with the given task ID. pub fn new(id: TaskId) -> Self { TaskMeta { id, @@ -117,6 +118,7 @@ impl TaskMeta { } } + /// Creates a new `TaskMeta` with the given task ID and weight. #[cfg(feature = "local-batch")] pub fn with_weight(task_id: TaskId, weight: u32) -> Self { TaskMeta { @@ -126,14 +128,15 @@ impl TaskMeta { } } + /// Returns the unique ID of this task within the `Hive`. pub fn id(&self) -> TaskId { self.id } - /// The number of previous failed attempts to execute the current task. + /// Returns the number of previous failed attempts to execute the current task. /// /// Always returns `0` if the `retry` feature is not enabled. - pub fn attempt(&self) -> u32 { + pub fn attempt(&self) -> u8 { #[cfg(feature = "retry")] return self.attempt; #[cfg(not(feature = "retry"))] @@ -159,7 +162,7 @@ impl TaskMeta { #[cfg(all(test, feature = "retry"))] impl TaskMeta { - pub fn with_attempt(task_id: TaskId, attempt: u32) -> Self { + pub fn with_attempt(task_id: TaskId, attempt: u8) -> Self { Self { id: task_id, attempt, diff --git a/src/bee/mod.rs b/src/bee/mod.rs index 9501a49..c42bc21 100644 --- a/src/bee/mod.rs +++ b/src/bee/mod.rs @@ -1,4 +1,3 @@ -#![cfg_attr(coverage_nightly, feature(coverage_attribute))] //! Traits for defining workers in the worker pool. //! //! A [`Hive`](crate::hive::Hive) is populated by bees: diff --git a/src/bee/queen.rs b/src/bee/queen.rs index 0f670f8..3c7628f 100644 --- a/src/bee/queen.rs +++ b/src/bee/queen.rs @@ -28,14 +28,17 @@ pub trait QueenMut: Send + Sync + 'static { pub struct QueenCell(RwLock); impl QueenCell { + /// Creates a new `QueenCell` with the given `mut_queen`. pub fn new(mut_queen: Q) -> Self { Self(RwLock::new(mut_queen)) } + /// Returns a reference to the wrapped `Queen`. pub fn get(&self) -> impl Deref { self.0.read() } + /// Consumes this `QueenCell` and returns the inner `Queen`. pub fn into_inner(self) -> Q { self.0.into_inner() } @@ -44,6 +47,7 @@ impl QueenCell { impl Queen for QueenCell { type Kind = Q::Kind; + /// Calls the wrapped `QueenMut::create` method using interior mutability. fn create(&self) -> Self::Kind { self.0.write().create() } diff --git a/src/hive/builder/bee.rs b/src/hive/builder/bee.rs index 0a20401..fc5d170 100644 --- a/src/hive/builder/bee.rs +++ b/src/hive/builder/bee.rs @@ -44,6 +44,8 @@ impl BeeBuilder { FullBuilder::from(self.config, self.queen) } + /// Creates a new `FullBuilder` with the current configuration and queen and workstealing + /// task queues. pub fn with_workstealing_queues(self) -> FullBuilder> { FullBuilder::from(self.config, self.queen) } @@ -90,6 +92,8 @@ impl BeeBuilder> { } impl BeeBuilder> { + /// Creates a new `BeeBuilder` with a `CloneQueen` created with the given `worker` and no + /// options configured. 
pub fn empty_with_worker(worker: W) -> Self { Self { config: Config::empty(), @@ -97,7 +101,9 @@ impl BeeBuilder> { } } - pub fn default_with_worker(worker: W) -> Self { + /// Creates a new `BeeBuilder` with a `CloneQueen` created with the given `worker` and + /// and options configured with global defaults. + pub fn preset_with_worker(worker: W) -> Self { Self { config: Config::default(), queen: CloneQueen::new(worker), @@ -106,6 +112,8 @@ impl BeeBuilder> { } impl BeeBuilder> { + /// Creates a new `BeeBuilder` with a `DefaultQueen` created with the given `Worker` type and + /// no options configured. pub fn empty_with_worker_default() -> Self { Self { config: Config::empty(), @@ -113,6 +121,8 @@ impl BeeBuilder> { } } + /// Creates a new `BeeBuilder` with a `DefaultQueen` created with the given `Worker` type and + /// and options configured with global defaults. pub fn preset_with_worker_default() -> Self { Self { config: Config::default(), diff --git a/src/hive/builder/full.rs b/src/hive/builder/full.rs index 138399c..9c2f4f2 100644 --- a/src/hive/builder/full.rs +++ b/src/hive/builder/full.rs @@ -12,6 +12,7 @@ pub struct FullBuilder> { } impl> FullBuilder { + /// Creates a new `FullBuilder` with the given queen and no options configured. pub fn empty>(queen: Q) -> Self { Self { config: Config::empty(), @@ -20,6 +21,8 @@ impl> FullBuilder { } } + /// Creates a new `FullBuilder` with the given `queen` and options configured with global + /// defaults. pub fn preset>(queen: I) -> Self { Self { config: Config::default(), @@ -28,6 +31,7 @@ impl> FullBuilder { } } + /// Creates a new `FullBuilder` from an existing `config` and a `queen`. pub(super) fn from(config: Config, queen: Q) -> Self { Self { config, @@ -36,6 +40,7 @@ impl> FullBuilder { } } + /// Consumes this `Builder` and returns a new [`Hive`]. pub fn build(self) -> Hive { Hive::new(self.config, self.queen) } diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs index 7c8e790..d5ab32e 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -51,6 +51,10 @@ pub use queue::TaskQueuesBuilder; pub use queue::channel::ChannelBuilder; pub use queue::workstealing::WorkstealingBuilder; +use crate::hive::inner::{BuilderConfig, Token}; + +/// Creates a new `OpenBuilder`. If `with_defaults` is `true`, the builder will be pre-configured +/// with the global defaults. pub fn open(with_defaults: bool) -> OpenBuilder { if with_defaults { OpenBuilder::default() @@ -59,6 +63,8 @@ pub fn open(with_defaults: bool) -> OpenBuilder { } } +/// Creates a new `ChannelBuilder`. If `with_defaults` is `true`, the builder will be +/// pre-configured with the global defaults. pub fn channel(with_defaults: bool) -> ChannelBuilder { if with_defaults { ChannelBuilder::default() @@ -66,7 +72,8 @@ pub fn channel(with_defaults: bool) -> ChannelBuilder { ChannelBuilder::empty() } } - +/// Creates a new `WorkstealingBuilder`. If `with_defaults` is `true`, the builder will be +/// pre-configured with the global defaults. 
pub fn workstealing(with_defaults: bool) -> WorkstealingBuilder { if with_defaults { WorkstealingBuilder::default() @@ -75,8 +82,6 @@ pub fn workstealing(with_defaults: bool) -> WorkstealingBuilder { } } -use crate::hive::inner::{BuilderConfig, Token}; - // #[cfg(all(test, feature = "affinity"))] // mod affinity_tests { // use super::{OpenBuilder, Token}; diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs index fa6981c..8f31f90 100644 --- a/src/hive/builder/open.rs +++ b/src/hive/builder/open.rs @@ -248,6 +248,8 @@ impl OpenBuilder { ChannelBuilder::from(self.0) } + /// Consumes this `Builder` and returns a new [`WorkstealingBuilder`] using the current + /// configuration. pub fn with_workstealing_queues(self) -> WorkstealingBuilder { WorkstealingBuilder::from(self.0) } diff --git a/src/hive/builder/queue.rs b/src/hive/builder/queue.rs index 66428e0..902a263 100644 --- a/src/hive/builder/queue.rs +++ b/src/hive/builder/queue.rs @@ -2,9 +2,12 @@ use super::FullBuilder; use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; use crate::hive::{Builder, TaskQueues}; +/// Trait implemented by builders specialized to a `TaskQueues` type. pub trait TaskQueuesBuilder: Builder + Clone + Default + Sized { + /// The type of the `TaskQueues` to use when building the `Hive`. type TaskQueues: TaskQueues; + /// Creates a new empty `Builder`. fn empty() -> Self; /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to @@ -56,6 +59,7 @@ pub mod channel { use crate::hive::builder::{BuilderConfig, Token}; use crate::hive::{ChannelTaskQueues, Config}; + /// `TaskQueuesBuilder` implementation for channel-based task queues. #[derive(Clone, Default)] pub struct ChannelBuilder(Config); @@ -95,6 +99,7 @@ pub mod workstealing { use crate::hive::builder::{BuilderConfig, Token}; use crate::hive::{Config, WorkstealingTaskQueues}; + /// `TaskQueuesBuilder` implementation for workstealing-based task queues. #[derive(Clone, Default)] pub struct WorkstealingBuilder(Config); diff --git a/src/hive/context.rs b/src/hive/context.rs index a958da2..36f58f5 100644 --- a/src/hive/context.rs +++ b/src/hive/context.rs @@ -1,3 +1,4 @@ +//! Implementation of `crate::bee::LocalContext` for a `Hive`. use crate::bee::{LocalContext, Queen, TaskId, Worker}; use crate::hive::{OutcomeSender, Shared, TaskQueues, WorkerQueues}; use std::fmt; @@ -20,6 +21,7 @@ where Q: Queen, T: TaskQueues, { + /// Creates a new `HiveLocalContext` instance. pub fn new( worker_queues: &'a T::WorkerQueues, shared: &'a Arc>, diff --git a/src/hive/cores.rs b/src/hive/cores.rs index a84a485..5f41253 100644 --- a/src/hive/cores.rs +++ b/src/hive/cores.rs @@ -1,4 +1,5 @@ //! Utilities for pinning worker threads to CPU cores in a `Hive`. +use core_affinity::{self, CoreId}; use parking_lot::Mutex; use std::collections::HashSet; use std::ops::{BitOr, BitOrAssign, Sub, SubAssign}; @@ -51,14 +52,14 @@ pub fn refresh() -> usize { #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub struct Core { /// the OS-specific core ID - id: core_affinity::CoreId, + id: CoreId, /// whether this core is currently available for pinning threads available: bool, } impl Core { /// Creates a new `Core` with `available` set to `true`. - fn new(core_id: core_affinity::CoreId) -> Self { + fn new(core_id: CoreId) -> Self { Self { id: core_id, available: true, @@ -89,7 +90,8 @@ impl Cores { Self(Vec::new()) } - /// Returns a `Cores` set populated with the first `n` CPU indices. 
+ /// Returns a `Cores` set populated with the first `n` CPU indices (up to the number of + /// available cores). pub fn first(n: usize) -> Self { Self(Vec::from_iter(0..n.min(num_cpus::get()))) } diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 488659b..56aade4 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -3,7 +3,7 @@ use super::{ OutcomeBatch, OutcomeIteratorExt, OutcomeSender, Sentinel, Shared, SpawnError, TaskInput, TaskQueues, TaskQueuesBuilder, }; -use crate::bee::{ApplyError, Context, DefaultQueen, Queen, TaskId, Worker}; +use crate::bee::{Context, DefaultQueen, Queen, TaskId, Worker}; use std::borrow::Borrow; use std::collections::HashMap; use std::fmt; @@ -66,7 +66,7 @@ impl, T: TaskQueues> Hive { let (task_meta, subtask_ids) = apply_ctx.into_parts(); let outcome = match result { #[cfg(feature = "retry")] - Err(ApplyError::Retryable { input, error }) + Err(crate::bee::ApplyError::Retryable { input, error }) if subtask_ids.is_none() && shared.can_retry(&task_meta) => { match shared.try_send_retry( @@ -821,6 +821,19 @@ mod local_batch { pub fn set_worker_batch_limit(&self, batch_limit: usize) { self.shared().set_worker_batch_limit(batch_limit); } + + /// Returns the weight limit for worker threads. + pub fn worker_weight_limit(&self) -> u64 { + self.shared().worker_weight_limit() + } + + /// Sets the weight limit for worker threads. + /// + /// Depending on this hive's `TaskQueues` implementation, this method may have no effect + /// (if it does not support local batching). + pub fn set_worker_weight_limit(&self, weight_limit: u64) { + self.shared().set_worker_weight_limit(weight_limit); + } } } @@ -837,12 +850,12 @@ mod retry { T: TaskQueues, { /// Returns the current retry limit for this hive. - pub fn worker_retry_limit(&self) -> u32 { + pub fn worker_retry_limit(&self) -> u8 { self.shared().worker_retry_limit() } /// Updates the retry limit for this hive and returns the previous value. - pub fn set_worker_retry_limit(&self, limit: u32) -> u32 { + pub fn set_worker_retry_limit(&self, limit: u8) -> u8 { self.shared().set_worker_retry_limit(limit) } diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 6a0cdc3..2b976fe 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -295,7 +295,7 @@ pub trait Builder: BuilderConfig + Sized { /// # } /// ``` #[cfg(feature = "retry")] - fn max_retries(mut self, limit: u32) -> Self { + fn max_retries(mut self, limit: u8) -> Self { let _ = if limit == 0 { self.config_ref(Token).max_retries.set(None) } else { diff --git a/src/hive/inner/config.rs b/src/hive/inner/config.rs index c943b35..5bc749e 100644 --- a/src/hive/inner/config.rs +++ b/src/hive/inner/config.rs @@ -157,9 +157,12 @@ mod local_batch { const DEFAULT_BATCH_LIMIT: usize = 10; + /// Sets the batch limit a `config` is configured with when using `Builder::default()`. pub fn set_batch_limit_default(batch_limit: usize) { DEFAULTS.lock().batch_limit.set(Some(batch_limit)); } + + /// Sets the weight limit a `config` is configured with when using `Builder::default()`. pub fn set_weight_limit_default(weight_limit: u64) { DEFAULTS.lock().weight_limit.set(Some(weight_limit)); } @@ -177,20 +180,20 @@ mod retry { use super::{Config, DEFAULTS}; use std::time::Duration; - const DEFAULT_MAX_RETRIES: u32 = 3; + const DEFAULT_MAX_RETRIES: u8 = 3; const DEFAULT_RETRY_FACTOR_SECS: u64 = 1; - /// Sets the max number of retries a `config` is configured with when using `Config::with_defaults()`. 
- pub fn set_max_retries_default(num_retries: u32) { + /// Sets the max number of retries a `config` is configured with when using `Builder::default()`. + pub fn set_max_retries_default(num_retries: u8) { DEFAULTS.lock().max_retries.set(Some(num_retries)); } - /// Sets the retry factor a `config` is configured with when using `Config::with_defaults()`. + /// Sets the retry factor a `config` is configured with when using `Builder::default()`. pub fn set_retry_factor_default(retry_factor: Duration) { DEFAULTS.lock().set_retry_factor_from(retry_factor); } - /// Specifies that retries should be disabled by default when using `Config::with_defaults()`. + /// Specifies that retries should be disabled by default when using `Builder::default()`. pub fn set_retries_default_disabled() { set_max_retries_default(0); } diff --git a/src/hive/inner/counter.rs b/src/hive/inner/counter.rs index 6a4741d..93072da 100644 --- a/src/hive/inner/counter.rs +++ b/src/hive/inner/counter.rs @@ -24,8 +24,8 @@ pub enum CounterError { /// The two values may be different sizes, but their total size in bits must equal the size of the /// data type (for now fixed to `64`) used to store the value. /// -/// Three operations are supported: -/// * increment the left counter (`L`) +/// The following operations are supported: +/// * increment/decrement the left counter (`L`) /// * decrement the right counter (`R`) /// * transfer an amount `N` from `L` to `R` (i.e., a simultaneous decrement of `L` and /// increment of `R` by the same amount) diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs index bcd5fa1..dbfaf3a 100644 --- a/src/hive/inner/mod.rs +++ b/src/hive/inner/mod.rs @@ -1,3 +1,4 @@ +//! Internal data structures needed to implement `Hive`. mod builder; mod config; mod counter; @@ -6,6 +7,7 @@ mod queue; mod shared; mod task; +/// Prelude-like module that collects all the functions for setting global configuration defaults. pub mod set_config { pub use super::config::{reset_defaults, set_num_threads_default, set_num_threads_default_all}; #[cfg(feature = "local-batch")] @@ -35,7 +37,7 @@ use std::thread::JoinHandle; type Any = AtomicOption>; type Usize = AtomicOption; #[cfg(feature = "retry")] -type U32 = AtomicOption; +type U8 = AtomicOption; #[cfg(feature = "retry")] type U64 = AtomicOption; @@ -50,7 +52,8 @@ pub struct Task { outcome_tx: Option>, } -/// Data shared by all worker threads in a `Hive`. +/// Data shared by all worker threads in a `Hive`. This is the private API used by the `Hive` and +/// worker threads to enqueue, dequeue, and process tasks. pub struct Shared> { /// core configuration parameters config: Config, @@ -105,7 +108,7 @@ pub struct Config { weight_limit: U64, /// Maximum number of retries for a task #[cfg(feature = "retry")] - max_retries: U32, + max_retries: U8, /// Multiplier for the retry backoff strategy #[cfg(feature = "retry")] retry_factor: U64, diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index 99c5be5..ae90b49 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -17,6 +17,10 @@ type TaskSender = crossbeam_channel::Sender>; /// Type alias for the input task channel receiver type TaskReceiver = crossbeam_channel::Receiver>; +/// `TaskQueues` implementation using `crossbeam` channels for the global queue. +/// +/// Worker threads may have access to local retry and/or batch queues, depending on which features +/// are enabled. 
pub struct ChannelTaskQueues { global: Arc>, local: RwLock>>>, @@ -162,6 +166,7 @@ impl WorkerQueues for ChannelWorkerQueues { } } +/// Worker thread-specific data shared with the main thread. struct LocalQueueShared { _thread_index: usize, /// queue of abandon tasks @@ -190,8 +195,8 @@ impl LocalQueueShared { } /// Updates the local queues based on the provided `config`: - /// If `local-batch` is enabled, resizes the batch queue if necessary. - /// If `retry` is enabled, updates the retry factor. + /// * If `local-batch` is enabled, resizes the batch queue if necessary. + /// * If `retry` is enabled, updates the retry factor. fn update(&self, _global: &GlobalQueue, _config: &Config) { #[cfg(feature = "local-batch")] self.local_batch.set_limits( @@ -278,6 +283,17 @@ mod local_batch { use crossbeam_queue::ArrayQueue; use parking_lot::RwLock; + /// Worker thread-local queue for tasks used to reduce the frequency of polling the global + /// queue (which may have a lot of contention from other worker threads). + /// + /// When the queue is empty, then it attempts to refill itself from the global queue. This is + /// done considering both the size and weight limits - i.e., the local queue is filled until + /// either it is full or the total weight of queued tasks exceeds the weight limit. + /// + /// This queue is implemented internally using a crossbeam `ArrayQueue`, which has a fixed size. + /// The queue can be resized dynamically by creating a new queue and copying the tasks over. If + /// the new queue is smaller than the old one, then any excess tasks are pushed back to the + /// global queue. pub struct WorkerBatchQueue { inner: RwLock>>>, batch_limit: AtomicUsize, diff --git a/src/hive/inner/queue/retry.rs b/src/hive/inner/queue/retry.rs index e5dfb5e..7469724 100644 --- a/src/hive/inner/queue/retry.rs +++ b/src/hive/inner/queue/retry.rs @@ -47,7 +47,7 @@ impl RetryQueue { Some(queue) => { // compute the delay let delay = 2u64 - .checked_pow(task.meta.attempt() - 1) + .checked_pow(task.meta.attempt() as u32 - 1) .and_then(|multiplier| { self.delay_factor .get() @@ -157,7 +157,7 @@ mod tests { impl Task { /// Creates a new `Task` with the given `task_id`. - fn with_attempt(task_id: TaskId, input: W::Input, attempt: u32) -> Self { + fn with_attempt(task_id: TaskId, input: W::Input, attempt: u8) -> Self { Self { input, meta: TaskMeta::with_attempt(task_id, attempt), diff --git a/src/hive/inner/queue/status.rs b/src/hive/inner/queue/status.rs index a59d550..4ef219e 100644 --- a/src/hive/inner/queue/status.rs +++ b/src/hive/inner/queue/status.rs @@ -4,23 +4,37 @@ const OPEN: u8 = 0; const CLOSED_PUSH: u8 = 1; const CLOSED_POP: u8 = 2; +/// Represents the status of a task queue. +/// +/// This is a simple state machine +/// OPEN -> CLOSED_PUSH -> CLOSED_POP +/// |________________________^ pub struct Status(AtomicU8); impl Status { + /// Returns `true` if the queue status is `CLOSED_PUSH` or `CLOSED_POP`. pub fn is_closed(&self) -> bool { self.0.get() > OPEN } + /// Returns `true` if the queue can accept new tasks. pub fn can_push(&self) -> bool { self.0.get() < CLOSED_PUSH } + /// Returns `true` if the queue can remove tasks. pub fn can_pop(&self) -> bool { self.0.get() < CLOSED_POP } + /// Sets the queue status to `CLOSED_PUSH` if `urgent` is `false`, or `CLOSED_POP` if `urgent` + /// is `true`. 
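The `RetryQueue` hunk above derives an exponential backoff from the attempt counter: attempt n is delayed by 2^(n-1) times the configured retry factor. A standalone sketch of that computation (storing the factor in nanoseconds and saturating on overflow are assumptions; the diff shows only the checked-arithmetic path):

    use std::time::Duration;

    /// Backoff as computed when re-queueing a failed task; `attempt` is at
    /// least 1 here because a task is only retried after one failure.
    fn retry_delay(attempt: u8, factor_nanos: u64) -> Duration {
        let nanos = 2u64
            .checked_pow(attempt as u32 - 1)
            .and_then(|multiplier| factor_nanos.checked_mul(multiplier))
            .unwrap_or(u64::MAX); // assumed fallback if the math overflows
        Duration::from_nanos(nanos)
    }

With the default retry factor of one second, attempts 1 through 4 wait 1s, 2s, 4s, and 8s respectively.
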
pub fn set(&self, urgent: bool) { - self.0.set(if urgent { CLOSED_POP } else { CLOSED_PUSH }); + // TODO: this update should be done with `fetch_max` + let new_status = if urgent { CLOSED_POP } else { CLOSED_PUSH }; + if new_status > self.0.get() { + self.0.set(new_status); + } } } diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 09dfe97..4596ddc 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -20,6 +20,7 @@ use std::time::Duration; /// Time to wait after trying to pop and finding all queues empty. const EMPTY_DELAY: Duration = Duration::from_millis(100); +/// `TaskQueues` implementation using workstealing. pub struct WorkstealingTaskQueues { global: Arc>, local: RwLock>>>, @@ -247,6 +248,7 @@ impl Deref for WorkstealingWorkerQueues { } } +/// Worker thread-specific data shared with the main thread. struct LocalQueueShared { _thread_index: usize, /// queue of abandon tasks diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index 79d9ef2..5d63175 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -49,7 +49,8 @@ impl, T: TaskQueues> Shared { } /// Spawns the initial set of `self.config.num_threads` worker threads using the provided - /// spawning function. Returns the number of worker threads that were successfully started. + /// spawning function. The results are stored in `self.spawn_results[0..num_threads]`. Returns + /// the number of worker threads that were successfully started. pub fn init_threads(&self, f: F) -> usize where F: Fn(usize) -> Result, SpawnError>, @@ -63,7 +64,7 @@ impl, T: TaskQueues> Shared { } /// Increases the maximum number of threads allowed in the `Hive` by `num_threads`, and - /// attempts to spawn threads with indices in `range = cur_index..cur_index + num_threads` + /// attempts to spawn threads with indices in the range `cur_index..cur_index + num_threads` /// using the provided spawning function. The results are stored in `self.spawn_results[range]`. /// Returns the number of new worker threads that were successfully started. pub fn grow_threads(&self, num_threads: usize, f: F) -> usize @@ -156,19 +157,6 @@ impl, T: TaskQueues> Shared { self.queen.create() } - /// Increments the number of queued tasks. Returns a new `Task` with the provided input and - /// `outcome_tx` and the next ID. - pub fn prepare_task(&self, input: I, outcome_tx: Option<&OutcomeSender>) -> Task - where - I: Into>, - { - self.num_tasks - .increment_left(1) - .expect("overflowed queued task counter"); - let task_id = self.next_task_id.add(1); - Task::new(task_id, input.into(), outcome_tx.cloned()) - } - /// Creates a new `Task` for the given input and outcome channel, and adds it to the global /// queue. pub fn send_one_global(&self, input: I, outcome_tx: Option<&OutcomeSender>) -> TaskId @@ -249,6 +237,19 @@ impl, T: TaskQueues> Shared { } } + /// Increments the number of queued tasks. Returns a new `Task` with the provided input and + /// `outcome_tx` and the next ID. + pub fn prepare_task(&self, input: I, outcome_tx: Option<&OutcomeSender>) -> Task + where + I: Into>, + { + self.num_tasks + .increment_left(1) + .expect("overflowed queued task counter"); + let task_id = self.next_task_id.add(1); + Task::new(task_id, input.into(), outcome_tx.cloned()) + } + /// Adds `task` to the global queue if possible, otherwise abandons it - converts it to an /// `Unprocessed` outcome and sends it to the outcome channel or stores it in the hive. 
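The TODO in `Status::set` above points out that the compare-then-set is two separate atomic operations, so a racing non-urgent close could slip in between them. A sketch of the monotonic update it suggests, written against `std::sync::atomic` (the crate uses its own atomic wrappers, so this is illustrative rather than a drop-in patch):

    use std::sync::atomic::{AtomicU8, Ordering};

    const OPEN: u8 = 0;
    const CLOSED_PUSH: u8 = 1;
    const CLOSED_POP: u8 = 2;

    struct Status(AtomicU8);

    impl Status {
        fn is_closed(&self) -> bool {
            self.0.load(Ordering::Acquire) > OPEN
        }

        /// `fetch_max` atomically keeps the greater of the current and
        /// proposed states, so the status can never move backwards.
        fn set(&self, urgent: bool) {
            let new_status = if urgent { CLOSED_POP } else { CLOSED_PUSH };
            self.0.fetch_max(new_status, Ordering::AcqRel);
        }
    }
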
pub fn push_global(&self, task: Task) { @@ -301,18 +302,6 @@ impl, T: TaskQueues> Shared { }) } - /// Sends an outcome to `outcome_tx`, or stores it in the `Hive` shared data if there is no - /// sender, or if the send fails. - pub fn send_or_store_outcome(&self, outcome: Outcome, outcome_tx: Option>) { - if let Some(outcome) = if let Some(tx) = outcome_tx { - tx.try_send_msg(outcome) - } else { - Some(outcome) - } { - self.add_outcome(outcome) - } - } - pub fn abandon_task(&self, task: Task) { let (outcome, outcome_tx) = task.into_unprocessed(); self.send_or_store_outcome(outcome, outcome_tx); @@ -352,16 +341,40 @@ impl, T: TaskQueues> Shared { task_ids } + #[cfg(feature = "local-batch")] + pub fn abandon_if_too_heavy(&self, task: Task) -> Result, TaskId> { + let weight_limit = self.config.weight_limit.get().unwrap_or(0); + if weight_limit > 0 && task.meta().weight() as u64 > weight_limit { + let task_id = task.id(); + let (outcome, outcome_tx) = task.into_overweight(); + self.send_or_store_outcome(outcome, outcome_tx); + // decrement the queued counter since it was incremented but the task was never queued + let _ = self.num_tasks.decrement_left(1); + self.no_work_notify_all(); + Err(task_id) + } else { + Ok(task) + } + } + + /// Sends an outcome to `outcome_tx`, or stores it in the `Hive` shared data if there is no + /// sender, or if the send fails. + pub fn send_or_store_outcome(&self, outcome: Outcome, outcome_tx: Option>) { + if let Some(outcome) = if let Some(tx) = outcome_tx { + tx.try_send_msg(outcome) + } else { + Some(outcome) + } { + self.add_outcome(outcome) + } + } + /// Called by a worker thread after completing a task. Notifies any thread that has `join`ed /// the `Hive` if there is no more work to be done. #[inline] pub fn finish_task(&self, panicking: bool) { - self.finish_tasks(1, panicking); - } - - pub fn finish_tasks(&self, n: u64, panicking: bool) { self.num_tasks - .decrement_right(n) + .decrement_right(1) .expect("active task counter was smaller than expected"); if panicking { self.num_panics.add(1); @@ -557,6 +570,74 @@ where } } +#[cfg(any(feature = "local-batch", feature = "retry"))] +mod update_config { + use super::Shared; + use crate::atomic::{Atomic, AtomicOption}; + use crate::bee::{Queen, Worker}; + use crate::hive::TaskQueues; + use std::fmt::Debug; + + impl Shared + where + W: Worker, + Q: Queen, + T: TaskQueues, + { + fn maybe_update(&self, new_value: P, option: &AtomicOption) -> P + where + P: Eq + Copy + Clone + Debug + Default, + A: Atomic
<P>
, + { + let prev_value = option.try_set(new_value).unwrap_or_default(); + if prev_value == new_value { + return prev_value; + } + let num_threads = self.num_threads(); + if num_threads == 0 { + return prev_value; + } + self.task_queues + .update_for_threads(0, num_threads, &self.config); + prev_value + } + + /// Changes the local queue batch size. This requires allocating a new queue for each + /// worker thread. + /// + /// Note: this method will block the current thread waiting for all local queues to become + /// writable; if `batch_limit` is less than the current batch size, this method will also + /// block while any thread's queue length is > `batch_limit` before moving the elements. + #[cfg(feature = "local-batch")] + pub fn set_worker_batch_limit(&self, batch_limit: usize) -> usize { + self.maybe_update(batch_limit, &self.config.batch_limit) + } + + /// Changes the local queue batch weight limit. + #[cfg(feature = "local-batch")] + pub fn set_worker_weight_limit(&self, weight_limit: u64) -> u64 { + self.maybe_update(weight_limit, &self.config.weight_limit) + } + + /// Sets the worker retry limit and returns the previous value. + #[cfg(feature = "retry")] + pub fn set_worker_retry_limit(&self, max_retries: u8) -> u8 { + self.maybe_update(max_retries, &self.config.max_retries) + } + + /// Sets the worker retry factor and returns the previous value. + #[cfg(feature = "retry")] + pub fn set_worker_retry_factor( + &self, + duration: std::time::Duration, + ) -> std::time::Duration { + std::time::Duration::from_nanos( + self.maybe_update(duration.as_nanos() as u64, &self.config.retry_factor), + ) + } + } +} + #[cfg(feature = "affinity")] mod affinity { use super::{Shared, TaskQueues}; @@ -589,9 +670,8 @@ mod affinity { #[cfg(feature = "local-batch")] mod local_batch { - use super::Shared; use crate::bee::{Queen, Worker}; - use crate::hive::TaskQueues; + use crate::hive::inner::{Shared, TaskQueues}; impl Shared where @@ -604,32 +684,9 @@ mod local_batch { self.config.batch_limit.get().unwrap_or_default() } - /// Changes the local queue batch size. This requires allocating a new queue for each - /// worker thread. - /// - /// Note: this method will block the current thread waiting for all local queues to become - /// writable; if `batch_limit` is less than the current batch size, this method will also - /// block while any thread's queue length is > `batch_limit` before moving the elements. - /// - /// TODO: this needs to be moved to an extension that is specific to channel hive - pub fn set_worker_batch_limit(&self, batch_limit: usize) -> usize { - // update the batch size first so any new threads spawned won't need to have their - // queues resized - let prev_batch_limit = self - .config - .batch_limit - .try_set(batch_limit) - .unwrap_or_default(); - if prev_batch_limit == batch_limit { - return prev_batch_limit; - } - let num_threads = self.num_threads(); - if num_threads == 0 { - return prev_batch_limit; - } - self.task_queues - .update_for_threads(0, num_threads, &self.config); - prev_batch_limit + /// Returns the local queue batch weight limit. 
A value of `0` means there is no weight + pub fn worker_weight_limit(&self) -> u64 { + self.config.weight_limit.get().unwrap_or_default() } } } @@ -639,7 +696,7 @@ mod retry { use crate::bee::{Queen, TaskMeta, Worker}; use crate::hive::inner::{Shared, Task, TaskQueues}; use crate::hive::{OutcomeSender, WorkerQueues}; - use std::time::{Duration, Instant}; + use std::time::Instant; impl Shared where @@ -648,52 +705,13 @@ mod retry { T: TaskQueues, { /// Returns the current worker retry limit. - pub fn worker_retry_limit(&self) -> u32 { + pub fn worker_retry_limit(&self) -> u8 { self.config.max_retries.get().unwrap_or_default() } - /// Sets the worker retry limit and returns the previous value. - pub fn set_worker_retry_limit(&self, max_retries: u32) -> u32 { - let prev_retry_limit = self - .config - .max_retries - .try_set(max_retries) - .unwrap_or_default(); - if prev_retry_limit == max_retries { - return prev_retry_limit; - } - let num_threads = self.num_threads(); - if num_threads == 0 { - return prev_retry_limit; - } - self.task_queues - .update_for_threads(0, num_threads, &self.config); - prev_retry_limit - } - /// Returns the current worker retry factor. - pub fn worker_retry_factor(&self) -> Duration { - Duration::from_millis(self.config.retry_factor.get().unwrap_or_default()) - } - - /// Sets the worker retry factor and returns the previous value. - pub fn set_worker_retry_factor(&self, duration: Duration) -> Duration { - let prev_retry_factor = Duration::from_nanos( - self.config - .retry_factor - .try_set(duration.as_nanos() as u64) - .unwrap_or_default(), - ); - if prev_retry_factor == duration { - return prev_retry_factor; - } - let num_threads = self.num_threads(); - if num_threads == 0 { - return prev_retry_factor; - } - self.task_queues - .update_for_threads(0, num_threads, &self.config); - prev_retry_factor + pub fn worker_retry_factor(&self) -> std::time::Duration { + std::time::Duration::from_millis(self.config.retry_factor.get().unwrap_or_default()) } /// Returns `true` if the hive is configured to retry tasks and the `attempt` field of the @@ -724,34 +742,6 @@ mod retry { } } -#[cfg(feature = "local-batch")] -mod weighting { - use crate::bee::{Queen, TaskId, Worker}; - use crate::hive::inner::{Shared, Task, TaskQueues}; - - impl Shared - where - W: Worker, - Q: Queen, - T: TaskQueues, - { - pub fn abandon_if_too_heavy(&self, task: Task) -> Result, TaskId> { - let weight_limit = self.config.weight_limit.get().unwrap_or(0); - if weight_limit > 0 && task.meta().weight() as u64 > weight_limit { - let task_id = task.id(); - let (outcome, outcome_tx) = task.into_overweight(); - self.send_or_store_outcome(outcome, outcome_tx); - // decrement the queued counter since it was incremented but the task was never queued - let _ = self.num_tasks.decrement_left(1); - self.no_work_notify_all(); - Err(task_id) - } else { - Ok(task) - } - } - } -} - #[cfg(test)] mod tests { use crate::bee::DefaultQueen; diff --git a/src/hive/inner/task.rs b/src/hive/inner/task.rs index d47e06e..4e547b4 100644 --- a/src/hive/inner/task.rs +++ b/src/hive/inner/task.rs @@ -2,6 +2,8 @@ use super::Task; use crate::bee::{TaskId, TaskMeta, Worker}; use crate::hive::{Outcome, OutcomeSender}; +/// The type of input to a task for a given `Worker` type. This changes depending on the features +/// that are enabled. pub use task_impl::TaskInput; impl Task { @@ -18,6 +20,22 @@ impl Task { } } + /// Creates a new `Task` with the given metadata, and increments the attempt number. 
+ #[cfg(feature = "retry")] + pub fn with_meta_inc_attempt( + input: W::Input, + mut meta: TaskMeta, + outcome_tx: Option>, + ) -> Self { + meta.inc_attempt(); + Self { + input, + meta, + outcome_tx, + } + } + + /// Returns the ID of the task. #[inline] pub fn id(&self) -> TaskId { self.meta.id() @@ -34,7 +52,7 @@ impl Task { (self.input, self.meta, self.outcome_tx) } - /// Consumes this `Task` and returns a `Outcome::Unprocessed` outcome with the input and ID, + /// Consumes this `Task` and returns an `Outcome::Unprocessed` outcome with the input and ID, /// and the outcome sender. pub fn into_unprocessed(self) -> (Outcome, Option>) { let outcome = Outcome::Unprocessed { @@ -92,8 +110,8 @@ mod task_impl { } } - /// Consumes this `Task` and returns a `Outcome::WeightLimitExceeded` outcome with the input, - /// weight, and ID, and the outcome sender. + /// Consumes this `Task` and returns a `Outcome::WeightLimitExceeded` outcome with the + /// input, weight, and ID, and the outcome sender. pub fn into_overweight(self) -> (Outcome, Option>) { let outcome = Outcome::WeightLimitExceeded { input: self.input, @@ -105,29 +123,6 @@ mod task_impl { } } -#[cfg(feature = "retry")] -mod retry { - use super::Task; - use crate::bee::{TaskMeta, Worker}; - use crate::hive::OutcomeSender; - - impl Task { - /// Creates a new `Task`. - pub fn with_meta_inc_attempt( - input: W::Input, - mut meta: TaskMeta, - outcome_tx: Option>, - ) -> Self { - meta.inc_attempt(); - Self { - input, - meta, - outcome_tx, - } - } - } -} - impl> Clone for Task { fn clone(&self) -> Self { Self { diff --git a/src/hive/mock.rs b/src/hive/mock.rs index 9158506..e79ea7b 100644 --- a/src/hive/mock.rs +++ b/src/hive/mock.rs @@ -1,24 +1,24 @@ -use super::{Task, TaskInput}; -use crate::bee::{Context, LocalContext, TaskId, TaskMeta, Worker, WorkerResult}; +//! Utilities for testing `Worker`s. +use super::{Outcome, Task, TaskInput}; +use crate::bee::{Context, LocalContext, TaskId, Worker}; use std::cell::RefCell; +/// A struct used for testing `Worker`s in a mock environment without needing to create a `Hive`. #[derive(Debug)] pub struct MockTaskRunner(RefCell); impl MockTaskRunner { + /// Creates a new `MockTaskRunner` with a starting task ID of 0. pub fn new() -> Self { Self(RefCell::new(0)) } /// Applies the given `worker` to the given `input`. /// - /// Returns a tuple of the apply result, the task metadata, and the IDs of any - /// subtasks that were submitted. - pub fn apply( - &self, - input: TaskInput, - worker: &mut W, - ) -> (WorkerResult, TaskMeta, Option>) { + /// The task ID is automatically incremented and used to create the `Context`. + /// + /// Returns the `Outcome` from executing the task. 
+ pub fn apply(&self, worker: &mut W, input: TaskInput) -> Outcome { let task_id = self.next_task_id(); let local = MockLocalContext(&self); let task: Task = Task::new(task_id, input, None); @@ -26,7 +26,7 @@ impl MockTaskRunner { let ctx = Context::new(task_meta, Some(&local)); let result = worker.apply(input, &ctx); let (task_meta, subtask_ids) = ctx.into_parts(); - (result, task_meta, subtask_ids) + Outcome::from_worker_result(result, task_meta, subtask_ids) } fn next_task_id(&self) -> TaskId { diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 88bf413..7883373 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -388,24 +388,27 @@ pub mod cores; mod hive; mod husk; mod inner; -#[cfg_attr(coverage_nightly, coverage(off))] pub mod mock; mod outcome; mod sentinel; mod util; +#[cfg(feature = "local-batch")] mod weighted; pub use self::builder::{BeeBuilder, ChannelBuilder, FullBuilder, OpenBuilder, TaskQueuesBuilder}; pub use self::builder::{ channel as channel_builder, open as open_builder, workstealing as workstealing_builder, }; +#[cfg(feature = "affinity")] +pub use self::cores::{Core, Cores}; pub use self::hive::{DefaultHive, Hive, Poisoned}; pub use self::husk::Husk; pub use self::inner::{ Builder, ChannelTaskQueues, TaskInput, WorkstealingTaskQueues, set_config::*, }; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; -pub use self::weighted::Weighted; +#[cfg(feature = "local-batch")] +pub use self::weighted::{Weighted, WeightedExactSizeIteratorExt, WeightedIteratorExt}; use self::context::HiveLocalContext; use self::inner::{Config, Shared, Task, TaskQueues, WorkerQueues}; @@ -428,6 +431,8 @@ pub fn outcome_channel() -> (OutcomeSender, OutcomeReceiver) { } pub mod prelude { + #[cfg(feature = "local-batch")] + pub use super::Weighted; pub use super::{ Builder, Hive, Husk, Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned, TaskQueuesBuilder, channel_builder, open_builder, outcome_channel, workstealing_builder, @@ -1998,8 +2003,8 @@ mod local_batch_tests { use crate::bee::DefaultQueen; use crate::bee::stock::{Thunk, ThunkWorker}; use crate::hive::{ - Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender, TaskQueues, - TaskQueuesBuilder, channel_builder, workstealing_builder, + Builder, Hive, Outcome, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender, TaskQueues, + TaskQueuesBuilder, WeightedExactSizeIteratorExt, channel_builder, workstealing_builder, }; use rstest::*; use std::collections::HashMap; @@ -2178,6 +2183,32 @@ mod local_batch_tests { assert!(thread_counts.values().all(|count| *count > BATCH_LIMIT_0)); assert_eq!(thread_counts.values().sum::(), total_tasks); } + + #[rstest] + fn test_swarm_default_weighted( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + const NUM_THREADS: usize = 4; + const BATCH_LIMIT: usize = 24; + let hive = builder_factory(false) + .with_worker_default::>() + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT) + .build(); + let inputs = (0..10u8) + .map(|i| { + Thunk::of(move || { + thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + i + }) + }) + .into_default_weighted(); + let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); + assert_eq!(outputs, (0..10).collect::>()) + } } #[cfg(all(test, feature = "retry"))] diff --git a/src/hive/outcome/batch.rs b/src/hive/outcome/batch.rs index af22e82..1bdbcdf 100644 --- a/src/hive/outcome/batch.rs +++ b/src/hive/outcome/batch.rs @@ -47,6 
+47,7 @@ impl DerefOutcomes for OutcomeBatch { } } +/// Functions only used in testing. #[cfg(test)] impl OutcomeBatch { pub(crate) fn empty() -> Self { diff --git a/src/hive/outcome/outcome.rs b/src/hive/outcome/impl.rs similarity index 99% rename from src/hive/outcome/outcome.rs rename to src/hive/outcome/impl.rs index f1aff75..6f97bd0 100644 --- a/src/hive/outcome/outcome.rs +++ b/src/hive/outcome/impl.rs @@ -4,18 +4,6 @@ use std::cmp::Ordering; use std::fmt::Debug; impl Outcome { - pub(in crate::hive) fn from_fatal( - input: W::Input, - task_meta: TaskMeta, - error: W::Error, - ) -> Self { - Self::Failure { - input: Some(input), - error, - task_id: task_meta.id(), - } - } - /// Converts a worker `result` into an `Outcome` with the given task_id and optional subtask ids. pub(in crate::hive) fn from_worker_result( result: WorkerResult, @@ -93,6 +81,20 @@ impl Outcome { } } + /// Creates a new `Outcome::Fatal` from the given input, task metadata, and error. + #[cfg(feature = "retry")] + pub(in crate::hive) fn from_fatal( + input: W::Input, + task_meta: TaskMeta, + error: W::Error, + ) -> Self { + Self::Failure { + input: Some(input), + error, + task_id: task_meta.id(), + } + } + /// Returns `true` if this is a `Success` outcome. pub fn is_success(&self) -> bool { matches!(self, Self::Success { .. }) diff --git a/src/hive/outcome/mod.rs b/src/hive/outcome/mod.rs index 985fc2b..10f577d 100644 --- a/src/hive/outcome/mod.rs +++ b/src/hive/outcome/mod.rs @@ -1,7 +1,6 @@ mod batch; mod iter; -#[allow(clippy::module_inception)] -mod outcome; +mod r#impl; mod queue; mod store; diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs index a571d2d..cb6c0a2 100644 --- a/src/hive/outcome/queue.rs +++ b/src/hive/outcome/queue.rs @@ -5,13 +5,16 @@ use parking_lot::Mutex; use std::collections::HashMap; use std::ops::{Deref, DerefMut}; +/// Data structure that supports queuing `Outcomes` from multiple threads (without locking) and +/// fetching from a single thread (which requires draining the queue into a map that is behind a +/// mutex). pub struct OutcomeQueue { queue: SegQueue>, outcomes: Mutex>>, } impl OutcomeQueue { - /// Adds an outcome to the queue. + /// Adds an `outcome` to the queue. pub fn push(&self, outcome: Outcome) { self.queue.push(outcome); } @@ -37,6 +40,7 @@ impl OutcomeQueue { outcomes } + /// Consumes this `OutcomeQueue`, drains the queue, and returns the outcomes as a map. pub fn into_inner(self) -> HashMap> { let mut outcomes = self.outcomes.into_inner(); // add any queued outcomes to the map diff --git a/src/hive/outcome/store.rs b/src/hive/outcome/store.rs index 1858653..8a93578 100644 --- a/src/hive/outcome/store.rs +++ b/src/hive/outcome/store.rs @@ -5,6 +5,8 @@ use std::{ ops::{Deref, DerefMut}, }; +/// Trait implemented by structs that provide temporary access (both read-only and mutable) to a +/// reference to a map of outcomes. pub trait DerefOutcomes { /// Returns a read-only reference to a map of task task_id to `Outcome`. fn outcomes_deref(&self) -> impl Deref>>; @@ -13,6 +15,7 @@ pub trait DerefOutcomes { fn outcomes_deref_mut(&mut self) -> impl DerefMut>>; } +/// Trait implemented by structs that provide (thread-unsafe) access to an owned outcome map. pub trait OwnedOutcomes: Sized { /// Returns an owned map of task task_id to `Outcome`. 
fn outcomes(self) -> HashMap>; diff --git a/src/hive/sentinel.rs b/src/hive/sentinel.rs index 795aa03..ce09878 100644 --- a/src/hive/sentinel.rs +++ b/src/hive/sentinel.rs @@ -2,7 +2,7 @@ use super::{Shared, TaskQueues}; use crate::bee::{Queen, Worker}; use std::io::Error as SpawnError; use std::sync::Arc; -use std::thread::JoinHandle; +use std::thread::{self, JoinHandle}; /// Sentinel for a worker thread. Until the sentinel is cancelled, it will respawn the worker /// thread if it panics. @@ -13,9 +13,13 @@ where T: TaskQueues, F: Fn(usize, &Arc>) -> Result, SpawnError> + 'static, { + /// The index of the worker thread thread_index: usize, + /// The shared data to pass to the new worker thread when respawning shared: Arc>, + /// Whether sentinel is active active: bool, + /// The function that will be called to respawn the worker thread respawn_fn: F, } @@ -52,10 +56,10 @@ where if self.active { // if the sentinel is active, that means the thread panicked during task execution, so // we have to finish the task here before respawning - self.shared.finish_task(std::thread::panicking()); + self.shared.finish_task(thread::panicking()); // only respawn if the sentinel is active and the hive has not been poisoned if !self.shared.is_poisoned() { - // can't do anything with the previous result + // can't do anything with the previous JoinHandle let _ = self .shared .respawn_thread(self.thread_index, |thread_index| { diff --git a/src/hive/util.rs b/src/hive/util.rs index 4513c0f..2aeca6c 100644 --- a/src/hive/util.rs +++ b/src/hive/util.rs @@ -1,3 +1,4 @@ +//! Internal utilities for the `hive` module. use crossbeam_utils::Backoff; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -7,8 +8,8 @@ const MAX_WAIT: Duration = Duration::from_secs(10); /// Utility function to loop (with exponential backoff) waiting for other references to `arc` to /// drop so it can be unwrapped into its inner value. /// -/// If `arc` cannot be unwrapped with a certain amount of time (with an exponentially -/// increasing gap between each iteration), `arc` is returned as an error. +/// If `arc` cannot be unwrapped within a certain amount of time (with an exponentially increasing +/// gap between each iteration), `arc` is returned as an error. pub fn unwrap_arc(mut arc: Arc) -> Result> { // wait for worker threads to drop, then take ownership of the shared data and convert it // into a Husk diff --git a/src/hive/weighted.rs b/src/hive/weighted.rs index cb5541d..b6defed 100644 --- a/src/hive/weighted.rs +++ b/src/hive/weighted.rs @@ -1,3 +1,4 @@ +//! Weighted value used for task submission with the `local-batch` feature. use std::ops::Deref; /// Wraps a value of type `T` and an associated weight. @@ -7,10 +8,13 @@ pub struct Weighted { } impl Weighted { + /// Creates a new `Weighted` instance with the given value and weight. pub fn new(value: T, weight: u32) -> Self { Self { value, weight } } + /// Creates a new `Weighted` instance with the given value and weight obtained from calling the + /// given function on `value`. pub fn from_fn(value: T, f: F) -> Self where F: FnOnce(&T) -> u32, @@ -19,26 +23,27 @@ impl Weighted { Self::new(value, weight) } - pub fn weight(&self) -> u32 { - self.weight + /// Creates a new `Weighted` instance with the given value and weight obtained by converting + /// the value into a `u32`. 
+ pub fn from_identity(value: T) -> Self + where + T: Into + Clone, + { + let weight = value.clone().into(); + Self::new(value, weight) } - pub fn into_inner(self) -> T { - self.value + /// Returns the weight associated with this `Weighted` value. + pub fn weight(&self) -> u32 { + self.weight } + /// Returns the value and weight as a tuple. pub fn into_parts(self) -> (T, u32) { (self.value, self.weight) } } -impl + Clone> Weighted { - pub fn from_identity(value: T) -> Self { - let weight = value.clone().into(); - Self::new(value, weight) - } -} - impl Deref for Weighted { type Target = T; @@ -59,114 +64,87 @@ impl From<(T, u32)> for Weighted { } } -// /// Trait implemented by a type that can be converted into an iterator over `Weighted` items. -// pub trait IntoWeightedIterator { -// type Item; -// type IntoIter: Iterator>; - -// fn into_weighted_iter(self) -> Self::IntoIter; -// } - -// // Calls to Hive task submission functions should work with iterators of `W::Input` regardless of -// // whether the `local-batch` feature is enabled. This blanket implementation of -// // `IntoWeightedIterator` just gives every task a weight of 0. -// impl> IntoWeightedIterator for I { -// type Item = T; -// type IntoIter = std::iter::Map Weighted>; - -// fn into_weighted_iter(self) -> Self::IntoIter { -// self.into_iter().map(Into::into) -// } -// } - -// struct WeightedIter(Box>>); - -// impl WeightedIter { -// fn from_tuples(tuples: I) -> Self -// where -// I: IntoIterator, -// I::IntoIter: 'static, -// { -// Self(Box::new(tuples.into_iter().map(Into::into))) -// } - -// fn from_iter(items: I, weights: W) -> Self -// where -// I: IntoIterator, -// I::IntoIter: 'static, -// W: IntoIterator, -// W::IntoIter: 'static, -// { -// Self(Box::new( -// items -// .into_iter() -// .zip(weights.into_iter().chain(std::iter::repeat(0))) -// .map(Into::into), -// )) -// } - -// fn from_const(items: I, weight: u32) -> Self -// where -// I: IntoIterator, -// I::IntoIter: 'static, -// { -// Self::from_iter(items, std::iter::repeat(weight)) -// } - -// fn from_fn(items: B, f: F) -> Self -// where -// B: IntoIterator, -// B::IntoIter: 'static, -// F: Fn(&T) -> u32 + 'static, -// { -// Self(Box::new(items.into_iter().map(move |value| { -// let weight = f(&value); -// Weighted::new(value, weight) -// }))) -// } -// } - -// impl + Clone + 'static> WeightedIter { -// fn from_identity(items: I) -> Self -// where -// I: IntoIterator, -// I::IntoIter: 'static, -// { -// Self::from_fn(items, |item| item.clone().into()) -// } -// } - -// impl IntoWeightedIterator for WeightedIter { -// type Item = T; -// type IntoIter = Box>>; - -// fn into_weighted_iter(self) -> Self::IntoIter { -// self.0 -// } -// } - -// #[cfg(feature = "local-batch")] -// pub fn map2(&self, batch: B) -> impl Iterator> + use -// where -// B: IntoWeightedIterator, -// { -// let (tx, rx) = outcome_channel(); -// let task_ids: Vec<_> = batch -// .into_weighted_iter() -// .map(|(task, weight)| self.apply_send(task, &tx)) -// .collect(); -// drop(tx); -// rx.select_ordered(task_ids) -// } -// #[cfg(test)] -// mod foo { -// use super::*; -// use crate::bee::stock::EchoWorker; - -// fn test_foo() { -// let hive = DefaultHive::>::default(); -// let result = hive -// .map2(Weighted::from_iter(0..10, 0..10)) -// .collect::>(); -// } -// } +/// Extends `IntoIterator` to add methods to convert any iterator into an iterator over `Weighted` +/// items. 
+pub trait WeightedIteratorExt: IntoIterator + Sized { + fn into_weighted(self) -> impl Iterator> + where + Self: IntoIterator, + { + self.into_iter() + .map(|(value, weight)| Weighted::new(value, weight)) + } + + fn into_default_weighted(self) -> impl Iterator> { + self.into_iter().map(Into::into) + } + + fn into_const_weighted(self, weight: u32) -> impl Iterator> { + self.into_iter() + .map(move |item| Weighted::new(item, weight)) + } + + fn into_identity_weighted(self) -> impl Iterator> + where + Self::Item: Into + Clone, + { + self.into_iter() + .map(move |item| Weighted::from_identity(item)) + } + + fn into_weighted_zip(self, weights: W) -> impl Iterator> + where + W: IntoIterator, + W::IntoIter: 'static, + { + self.into_iter() + .zip(weights.into_iter().chain(std::iter::repeat(0))) + .map(Into::into) + } +} + +impl WeightedIteratorExt for T {} + +/// Extends `IntoIterator` to add methods to convert any iterator into an iterator over `Weighted` +/// items. +pub trait WeightedExactSizeIteratorExt: IntoIterator + Sized { + fn into_weighted(self) -> impl ExactSizeIterator> + where + Self: IntoIterator, + Self::IntoIter: ExactSizeIterator + 'static, + { + self.into_iter() + .map(|(value, weight)| Weighted::new(value, weight)) + } + + fn into_default_weighted(self) -> impl ExactSizeIterator> + where + Self::IntoIter: ExactSizeIterator + 'static, + { + self.into_iter().map(Into::into) + } + + fn into_const_weighted(self, weight: u32) -> impl ExactSizeIterator> + where + Self::IntoIter: ExactSizeIterator + 'static, + { + self.into_iter() + .map(move |item| Weighted::new(item, weight)) + } + + fn into_identity_weighted(self) -> impl ExactSizeIterator> + where + Self::Item: Into + Clone, + Self::IntoIter: ExactSizeIterator + 'static, + { + self.into_iter() + .map(move |item| Weighted::from_identity(item)) + } +} + +impl WeightedExactSizeIteratorExt for T +where + T: IntoIterator, + T::IntoIter: ExactSizeIterator, +{ +} diff --git a/src/lib.rs b/src/lib.rs index c55e8bf..351bfd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -325,7 +325,6 @@ //! } //! # } //! ``` - mod atomic; #[cfg(test)] mod barrier; diff --git a/src/util.rs b/src/util.rs index d8793b7..0ae0be8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,8 +1,8 @@ //! Utility functions for simple use cases. //! //! In all cases, the number of threads is specified as a parameter, and the function takes care of -//! creating the [`Hive`](crate::hive::Hive), submitting tasks, collecting results, and shutting -//! down the `Hive` properly. +//! creating the [`Hive`](crate::hive::Hive) (with channel-based task queues), submitting tasks, +//! collecting results, and shutting down the `Hive` properly. 
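A short sketch of how the `WeightedIteratorExt` methods above attach weights at submission time (this assumes the `local-batch` feature is enabled; the import path follows the re-exports added to `src/hive/mod.rs` in this patch):

    use beekeeper::hive::{Weighted, WeightedIteratorExt};

    fn weighting_examples() {
        // every task gets the same weight
        let heavy: Vec<Weighted<u32>> = (0..4u32).into_const_weighted(10).collect();
        assert!(heavy.iter().all(|w| w.weight() == 10));

        // weight derived from the value itself (Item: Into<u32> + Clone)
        let identity: Vec<Weighted<u32>> = (0..4u32).into_identity_weighted().collect();
        assert_eq!(identity[3].weight(), 3);

        // zip with a separate weight iterator; missing weights default to 0
        let zipped: Vec<Weighted<&str>> =
            ["a", "b", "c"].into_weighted_zip([5u32, 7]).collect();
        assert_eq!(zipped[2].weight(), 0);
    }
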
 use crate::bee::stock::{Caller, OnceCaller};
 use crate::hive::{Builder, ChannelBuilder, Outcome, OutcomeBatch, TaskQueuesBuilder};
 use std::fmt::Debug;
@@ -143,7 +143,7 @@ mod retry {
     /// ```
     pub fn try_map_retryable(
         num_threads: usize,
-        max_retries: u32,
+        max_retries: u8,
         inputs: Inputs,
         f: F,
     ) -> OutcomeBatch>

From d180c2b2552fcb59b5655e7fa08eb482fce95735 Mon Sep 17 00:00:00 2001
From: jdidion
Date: Wed, 12 Mar 2025 12:48:55 -0700
Subject: [PATCH 45/67] add tests

---
 CHANGELOG.md | 3 +-
 Cargo.toml | 6 +
 src/atomic.rs | 3 +
 src/bee/context.rs | 15 +-
 src/bee/error.rs | 1 +
 src/bee/queen.rs | 119 ++++++++++-
 src/bee/stock/call.rs | 292 ++++++++++++++++-----------
 src/bee/stock/echo.rs | 15 +-
 src/bee/stock/thunk.rs | 51 +++--
 src/bee/worker.rs | 1 +
 src/channel.rs | 4 +
 src/hive/builder/bee.rs | 167 ++++++++++++++-
 src/hive/builder/full.rs | 20 +-
 src/hive/builder/mod.rs | 40 ++--
 src/hive/builder/open.rs | 12 +-
 src/hive/builder/queue.rs | 8 +-
 src/hive/cores.rs | 12 +-
 src/hive/hive.rs | 25 +--
 src/hive/husk.rs | 16 +-
 src/hive/inner/builder.rs | 156 +++++++++++++-
 src/hive/inner/config.rs | 3 +
 src/hive/inner/counter.rs | 1 +
 src/hive/inner/mod.rs | 12 ++
 src/hive/inner/queue/channel.rs | 4 +
 src/hive/inner/queue/retry.rs | 1 +
 src/hive/inner/queue/workstealing.rs | 5 +-
 src/hive/inner/shared.rs | 3 +-
 src/hive/inner/task.rs | 43 ++--
 src/hive/mock.rs | 87 +++++++-
 src/hive/mod.rs | 258 +++++++++++++++++------
 src/hive/outcome/batch.rs | 5 +
 src/hive/outcome/impl.rs | 238 ++++++++++++----------
 src/hive/outcome/iter.rs | 1 +
 src/hive/outcome/mod.rs | 24 ++-
 src/hive/outcome/queue.rs | 79 ++++++--
 src/hive/outcome/store.rs | 1 +
 src/hive/weighted.rs | 182 +++++++++++++--
 src/lib.rs | 2 +
 src/panic.rs | 13 +-
 src/util.rs | 8 +-
 40 files changed, 1469 insertions(+), 467 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31b7597..7d5356c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,7 +16,8 @@ The general theme of this release is performance improvement by eliminating thre
 * `beekeeper::bee::Queen::create` now takes `&self` rather than `&mut self`. There is a new type, `beekeeper::bee::QueenMut`, with a `create(&mut self)` method, which must be wrapped in a `beekeeper::bee::QueenCell` to implement the `Queen` trait. This enables the `Hive` to create new workers without locking in the case of a `Queen` that does not need mutable state.
 * `beekeeper::bee::Context` now takes a generic parameter that must be the input type of the `Worker`.
 * `beekeeper::hive::Hive::try_into_husk` now has an `urgent` parameter to indicate whether queued tasks should be abandoned when shutting down the hive (`true`) or if they should be allowed to finish processing (`false`).
- * The type of `attempt` and `max_retries` has been changed to `u8`. This reduces memory usage and should cover the majority of use cases.
+ * The type of `attempt` and `max_retries` has been changed to `u8`. This reduces memory usage and should still allow for the majority of use cases.
+ * The `::of` methods have been removed from stock `Worker`s in favor of implementing `From`.
* Features
 * Added the `TaskQueues` trait, which enables `Hive` to be specialized for different implementations of global (i.e., sending tasks from the `Hive` to worker threads) and local (i.e., worker thread-specific) queues.
 * `ChannelTaskQueues` implements the existing behavior, using a channel for sending tasks.
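To illustrate the `::of` removal noted in the changelog above, a minimal migration sketch (hypothetical closures; `Caller` and `Thunk` are the stock types updated by this patch):

    // before this patch:
    // let caller = Caller::of(|i: u8| i + 1);
    // let thunk = Thunk::of(|| 5);
    // after, via the `From` impls:
    let caller = Caller::from(|i: u8| i + 1);
    let thunk = Thunk::from(|| 5);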
diff --git a/Cargo.toml b/Cargo.toml index 4f2c9aa..816458b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ crossbeam-channel = "0.5.13" crossbeam-deque = "0.8.6" crossbeam-queue = "0.3.12" crossbeam-utils = "0.8.20" +derive_more = { version = "2.0.1", features = ["debug"] } num = "0.4.3" num_cpus = "1.16.0" parking_lot = "0.12.3" @@ -49,6 +50,11 @@ crossbeam = [] flume = ["dep:flume"] loole = ["dep:loole"] +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = [ + 'cfg(coverage,coverage_nightly)', +] } + [package.metadata.cargo-all-features] allowlist = ["affinity", "local-batch", "retry"] diff --git a/src/atomic.rs b/src/atomic.rs index 4bf8791..3b58e3e 100644 --- a/src/atomic.rs +++ b/src/atomic.rs @@ -361,6 +361,7 @@ mod affinity { } #[cfg(test)] + #[cfg_attr(coverage_nightly, coverage(off))] mod tests { use crate::atomic::{AtomicAny, AtomicOption, MutError}; @@ -411,6 +412,7 @@ mod local_batch { } #[cfg(test)] + #[cfg_attr(coverage_nightly, coverage(off))] mod tests { use crate::atomic::{AtomicOption, AtomicUsize, MutError}; @@ -434,6 +436,7 @@ mod local_batch { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::*; use paste::paste; diff --git a/src/bee/context.rs b/src/bee/context.rs index 4712564..36eb2e8 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -27,7 +27,7 @@ impl<'a, I> Context<'a, I> { /// Returns a new empty context. This is primarily useful for testing. pub fn empty() -> Self { Self { - meta: TaskMeta::empty(), + meta: TaskMeta::default(), local: None, subtask_ids: RefCell::new(None), } @@ -95,7 +95,7 @@ impl<'a, I> Context<'a, I> { } /// The metadata of a task. -#[derive(Default, Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct TaskMeta { id: TaskId, #[cfg(feature = "local-batch")] @@ -105,11 +105,6 @@ pub struct TaskMeta { } impl TaskMeta { - /// Creates an empty `TaskMeta` with a default task ID. - pub fn empty() -> Self { - Self::new(0) - } - /// Creates a new `TaskMeta` with the given task ID. pub fn new(id: TaskId) -> Self { TaskMeta { @@ -160,6 +155,12 @@ impl TaskMeta { } } +impl From for TaskMeta { + fn from(value: TaskId) -> Self { + TaskMeta::new(value) + } +} + #[cfg(all(test, feature = "retry"))] impl TaskMeta { pub fn with_attempt(task_id: TaskId, attempt: u8) -> Self { diff --git a/src/bee/error.rs b/src/bee/error.rs index 5a0137e..0cc5138 100644 --- a/src/bee/error.rs +++ b/src/bee/error.rs @@ -91,6 +91,7 @@ impl From for ApplyRefError { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::ApplyError; use crate::panic::Panic; diff --git a/src/bee/queen.rs b/src/bee/queen.rs index 3c7628f..c0c93c8 100644 --- a/src/bee/queen.rs +++ b/src/bee/queen.rs @@ -1,8 +1,10 @@ //! The Queen bee trait. use super::Worker; +use derive_more::Debug; use parking_lot::RwLock; use std::marker::PhantomData; use std::ops::Deref; +use std::{any, fmt}; /// A trait for factories that create `Worker`s. 
pub trait Queen: Send + Sync + 'static { @@ -53,6 +55,20 @@ impl Queen for QueenCell { } } +impl fmt::Debug for QueenCell { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("QueenCell") + .field("queen", &*self.0.read()) + .finish() + } +} + +impl Clone for QueenCell { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().clone())) + } +} + impl Default for QueenCell { fn default() -> Self { Self::new(Q::default()) @@ -98,8 +114,15 @@ impl From for QueenCell { /// } /// ``` #[derive(Default, Debug)] +#[debug("DefaultQueen<{}>", any::type_name::())] pub struct DefaultQueen(PhantomData); +impl Clone for DefaultQueen { + fn clone(&self) -> Self { + Self::default() + } +} + impl Queen for DefaultQueen { type Kind = W; @@ -111,6 +134,7 @@ impl Queen for DefaultQueen { /// A `Queen` that can create a `Worker` type that implements `Clone`, by making copies of /// an existing instance of that `Worker` type. #[derive(Debug)] +#[debug("CloneQueen<{}>", any::type_name::())] pub struct CloneQueen(W); impl CloneQueen { @@ -119,6 +143,18 @@ impl CloneQueen { } } +impl Clone for CloneQueen { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Default for CloneQueen { + fn default() -> Self { + Self(W::default()) + } +} + impl Queen for CloneQueen { type Kind = W; @@ -128,4 +164,85 @@ impl Queen for CloneQueen { } #[cfg(test)] -mod tests {} +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use super::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut}; + use crate::bee::stock::EchoWorker; + + #[derive(Default, Debug, Clone)] + struct TestQueen(usize); + + impl QueenMut for TestQueen { + type Kind = EchoWorker; + + fn create(&mut self) -> Self::Kind { + self.0 += 1; + EchoWorker::default() + } + } + + #[test] + fn test_queen_cell() { + let queen = QueenCell::new(TestQueen(0)); + for _ in 0..10 { + let _worker = queen.create(); + } + assert_eq!(queen.get().0, 10); + assert_eq!(queen.into_inner().0, 10); + } + + #[test] + fn test_queen_cell_default() { + let queen = QueenCell::::default(); + for _ in 0..10 { + let _worker = queen.create(); + } + assert_eq!(queen.get().0, 10); + } + + #[test] + fn test_queen_cell_clone() { + let queen = QueenCell::::default(); + for _ in 0..10 { + let _worker = queen.create(); + } + assert_eq!(queen.clone().get().0, 10); + } + + #[test] + fn test_queen_cell_debug() { + let queen = QueenCell::::default(); + for _ in 0..10 { + let _worker = queen.create(); + } + assert_eq!(format!("{:?}", queen), "QueenCell { queen: TestQueen(10) }"); + } + + #[test] + fn test_queen_cell_from() { + let queen = QueenCell::from(TestQueen::default()); + for _ in 0..10 { + let _worker = queen.create(); + } + assert_eq!(queen.get().0, 10); + } + + #[test] + fn test_default_queen() { + let queen1 = DefaultQueen::>::default(); + let worker1 = queen1.create(); + let queen2 = queen1.clone(); + let worker2 = queen2.create(); + assert_eq!(worker1, worker2); + } + + #[test] + fn test_clone_queen() { + let worker = EchoWorker::::default(); + let queen = CloneQueen::new(worker); + let worker1 = queen.create(); + let queen2 = queen.clone(); + let worker2 = queen2.create(); + assert_eq!(worker1, worker2); + } +} diff --git a/src/bee/stock/call.rs b/src/bee/stock/call.rs index daa5890..54e9c00 100644 --- a/src/bee/stock/call.rs +++ b/src/bee/stock/call.rs @@ -2,15 +2,27 @@ use crate::bee::{ ApplyError, ApplyRefError, Context, RefWorker, RefWorkerResult, Worker, WorkerResult, }; -use std::fmt::Debug; +use derive_more::Debug; use 
std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; +use std::{any, fmt}; /// Wraps a closure or function pointer and calls it when applied. For this `Callable` to be /// useable by a `Worker`, the function must be `FnMut` *and* `Clone`able. +/// +/// TODO: we could provide a better `Debug` implementation by providing a macro that can wrap a +/// closure and store the text of the function, and then change all the Workers to take a +/// `F: Deref`. +/// See https://users.rust-lang.org/t/is-it-possible-to-implement-debug-for-fn-type/14824/3 +#[derive(Debug)] struct Callable { + #[debug(skip)] f: F, + #[debug("{}", any::type_name::())] i: PhantomData, + #[debug("{}", any::type_name::())] o: PhantomData, + #[debug("{}", any::type_name::())] e: PhantomData, } @@ -23,6 +35,10 @@ impl Callable { e: PhantomData, } } + + fn into_inner(self) -> F { + self.f + } } impl Clone for Callable { @@ -31,18 +47,52 @@ impl Clone for Callable { } } +impl Deref for Callable { + type Target = F; + + fn deref(&self) -> &Self::Target { + &self.f + } +} + +impl DerefMut for Callable { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.f + } +} + /// A `Caller` that executes its function once on the input and returns the output. The function /// should not panic. -pub struct Caller(Callable); +#[derive(Debug)] +pub struct Caller { + callable: Callable, +} impl Caller { - pub fn of(f: F) -> Self - where - I: Send + Sync + 'static, - O: Send + Sync + 'static, - F: FnMut(I) -> O + Clone + 'static, - { - Caller(Callable::of(f)) + /// Returns the wrapped callable. + pub fn into_inner(self) -> F { + self.callable.into_inner() + } +} + +impl From for Caller +where + I: Send + Sync + 'static, + O: Send + Sync + 'static, + F: FnMut(I) -> O + Clone + 'static, +{ + fn from(f: F) -> Self { + Caller { + callable: Callable::of(f), + } + } +} + +impl Clone for Caller { + fn clone(&self) -> Self { + Self { + callable: self.callable.clone(), + } } } @@ -58,43 +108,45 @@ where #[inline] fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { - Ok((self.0.f)(input)) + Ok((self.callable)(input)) } } -impl O> Debug for Caller { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("Caller") - } +/// A `Caller` that executes its function once on each input. The input value is consumed by the +/// function. If the function returns an error, it is wrapped in `ApplyError::Fatal`. +/// +/// If ownership of the input value is not required, consider using `RefCaller` instead. +#[derive(Debug)] +pub struct OnceCaller { + callable: Callable, } -impl Clone for Caller { - fn clone(&self) -> Self { - Self(self.0.clone()) +impl OnceCaller { + /// Returns the wrapped callable. + pub fn into_inner(self) -> F { + self.callable.into_inner() } } -impl O + Clone + 'static> From for Caller { +impl From for OnceCaller +where + I: Send + Sync + 'static, + O: Send + Sync + 'static, + E: Send + Sync + fmt::Debug + 'static, + F: FnMut(I) -> Result + Clone + 'static, +{ fn from(f: F) -> Self { - Caller(Callable::of(f)) + OnceCaller { + callable: Callable::of(f), + } } } -/// A `Caller` that executes its function once on each input. The input value is consumed by the -/// function. If the function returns an error, it is wrapped in `ApplyError::Fatal`. -/// -/// If ownership of the input value is not required, consider using `RefCaller` instead. 
-pub struct OnceCaller(Callable); - -impl OnceCaller { - pub fn of(f: F) -> Self - where - I: Send + Sync + 'static, - O: Send + Sync + 'static, - E: Send + Sync + Debug + 'static, - F: FnMut(I) -> Result + Clone + 'static, - { - OnceCaller(Callable::of(f)) +impl Clone for OnceCaller { + fn clone(&self) -> Self { + Self { + callable: self.callable.clone(), + } } } @@ -102,7 +154,7 @@ impl Worker for OnceCaller where I: Send + 'static, O: Send + 'static, - E: Send + Debug + 'static, + E: Send + fmt::Debug + 'static, F: FnMut(I) -> Result + Clone + 'static, { type Input = I; @@ -111,47 +163,46 @@ where #[inline] fn apply(&mut self, input: Self::Input, _: &Context) -> WorkerResult { - (self.0.f)(input).map_err(|error| ApplyError::Fatal { error, input: None }) + (self.callable)(input).map_err(|error| ApplyError::Fatal { error, input: None }) } } -impl Result> Debug for OnceCaller { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("OnceCaller") - } +/// A `Caller` that executes its function once on a reference to the input. If the function +/// returns an error, it is wrapped in `ApplyError::Fatal`. +/// +/// The benefit of using `RefCaller` over `OnceCaller` is that the `Fatal` error +/// contains the input value for later recovery. +#[derive(Debug)] +pub struct RefCaller { + callable: Callable, } -impl Clone for OnceCaller { - fn clone(&self) -> Self { - Self(self.0.clone()) +impl RefCaller { + /// Returns the wrapped callable. + pub fn into_inner(self) -> F { + self.callable.into_inner() } } -impl From for OnceCaller +impl From for RefCaller where - F: FnMut(I) -> Result + Clone + 'static, + I: Send + Sync + 'static, + O: Send + Sync + 'static, + E: Send + Sync + fmt::Debug + 'static, + F: FnMut(&I) -> Result + Clone + 'static, { fn from(f: F) -> Self { - OnceCaller(Callable::of(f)) + RefCaller { + callable: Callable::of(f), + } } } -/// A `Caller` that executes its function once on a reference to the input. If the function -/// returns an error, it is wrapped in `ApplyError::Fatal`. -/// -/// The benefit of using `RefCaller` over `OnceCaller` is that the `Fatal` error -/// contains the input value for later recovery. -pub struct RefCaller(Callable); - -impl RefCaller { - pub fn of(f: F) -> Self - where - I: Send + Sync + 'static, - O: Send + Sync + 'static, - E: Send + Sync + Debug + 'static, - F: FnMut(&I) -> Result + Clone + 'static, - { - RefCaller(Callable::of(f)) +impl Clone for RefCaller { + fn clone(&self) -> Self { + Self { + callable: self.callable.clone(), + } } } @@ -159,7 +210,7 @@ impl RefWorker for RefCaller where I: Send + 'static, O: Send + 'static, - E: Send + Debug + 'static, + E: Send + fmt::Debug + 'static, F: FnMut(&I) -> Result + Clone + 'static, { type Input = I; @@ -172,44 +223,43 @@ where input: &Self::Input, _: &Context, ) -> RefWorkerResult { - (self.0.f)(input).map_err(|error| ApplyRefError::Fatal(error)) + (self.callable)(input).map_err(|error| ApplyRefError::Fatal(error)) } } -impl Result> Debug for RefCaller { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("RefCaller") - } +/// A `Caller` that returns a `Result`. A result of `Err(ApplyError::Retryable)` +/// can be returned to indicate the task should be retried. +#[derive(Debug)] +pub struct RetryCaller { + callable: Callable, } -impl Clone for RefCaller { - fn clone(&self) -> Self { - Self(self.0.clone()) +impl RetryCaller { + /// Returns the wrapped callable. 
+ pub fn into_inner(self) -> F { + self.callable.into_inner() } } -impl From for RefCaller +impl From for RetryCaller where - F: FnMut(&I) -> Result + Clone + 'static, + I: Send + Sync + 'static, + O: Send + Sync + 'static, + E: Send + Sync + fmt::Debug + 'static, + F: FnMut(I, &Context) -> Result> + Clone + 'static, { fn from(f: F) -> Self { - RefCaller(Callable::of(f)) + RetryCaller { + callable: Callable::of(f), + } } } -/// A `Caller` that returns a `Result`. A result of `Err(ApplyError::Retryable)` -/// can be returned to indicate the task should be retried. -pub struct RetryCaller(Callable); - -impl RetryCaller { - pub fn of(f: F) -> Self - where - I: Send + Sync + 'static, - O: Send + Sync + 'static, - E: Send + Sync + Debug + 'static, - F: FnMut(I, &Context) -> Result> + Clone + 'static, - { - RetryCaller(Callable::of(f)) +impl Clone for RetryCaller { + fn clone(&self) -> Self { + Self { + callable: self.callable.clone(), + } } } @@ -217,7 +267,7 @@ impl Worker for RetryCaller where I: Send + 'static, O: Send + 'static, - E: Send + Debug + 'static, + E: Send + fmt::Debug + 'static, F: FnMut(I, &Context) -> Result> + Clone + 'static, { type Input = I; @@ -226,44 +276,30 @@ where #[inline] fn apply(&mut self, input: Self::Input, ctx: &Context) -> WorkerResult { - (self.0.f)(input, ctx) - } -} - -impl Clone for RetryCaller { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - -impl) -> Result>> Debug - for RetryCaller -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("RetryCaller") - } -} - -impl From for RetryCaller -where - F: FnMut(I, &Context) -> Result> + Clone + 'static, -{ - fn from(f: F) -> Self { - RetryCaller(Callable::of(f)) + (self.callable)(input, ctx) } } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::*; use crate::bee::Context; #[test] fn test_call() { - let mut worker = Caller::of(|input: u8| input + 1); + let mut worker = Caller::from(|input: u8| input + 1); assert!(matches!(worker.apply(5, &Context::empty()), Ok(6))) } + #[test] + fn test_clone() { + let worker1 = Caller::from(|input: u8| input + 1); + let worker2 = worker1.clone(); + let f = worker2.into_inner(); + assert_eq!(f(5), 6); + } + #[allow(clippy::type_complexity)] fn try_caller() -> RetryCaller< (bool, u8), @@ -273,7 +309,7 @@ mod tests { + Clone + 'static, > { - RetryCaller::of(|input: (bool, u8), _: &Context<(bool, u8)>| { + RetryCaller::from(|input: (bool, u8), _: &Context<(bool, u8)>| { if input.0 { Ok(input.1 + 1) } else { @@ -291,6 +327,14 @@ mod tests { assert!(matches!(worker.apply((true, 5), &Context::empty()), Ok(6))); } + #[test] + fn test_clone_retry_caller() { + let worker1 = try_caller(); + let worker2 = worker1.clone(); + let mut f = worker2.into_inner(); + assert!(matches!(f((true, 5), &Context::empty()), Ok(6))); + } + #[test] fn test_try_call_fail() { let mut worker = try_caller(); @@ -312,7 +356,7 @@ mod tests { String, impl FnMut((bool, u8)) -> Result + Clone + 'static, > { - OnceCaller::of(|input: (bool, u8)| { + OnceCaller::from(|input: (bool, u8)| { if input.0 { Ok(input.1 + 1) } else { @@ -327,6 +371,14 @@ mod tests { assert!(matches!(worker.apply((true, 5), &Context::empty()), Ok(6))); } + #[test] + fn test_clone_once_caller() { + let worker1 = once_caller(); + let worker2 = worker1.clone(); + let mut f = worker2.into_inner(); + assert!(matches!(f((true, 5)), Ok(6))); + } + #[test] fn test_once_call_fail() { let mut worker = once_caller(); @@ -348,7 +400,7 @@ mod tests { String, impl FnMut(&(bool, 
u8)) -> Result + Clone + 'static, > { - RefCaller::of(|input: &(bool, u8)| { + RefCaller::from(|input: &(bool, u8)| { if input.0 { Ok(input.1 + 1) } else { @@ -363,6 +415,14 @@ mod tests { assert!(matches!(worker.apply((true, 5), &Context::empty()), Ok(6))); } + #[test] + fn test_clone_ref_caller() { + let worker1 = ref_caller(); + let worker2 = worker1.clone(); + let mut f = worker2.into_inner(); + assert!(matches!(f(&(true, 5)), Ok(6))); + } + #[test] fn test_ref_call_fail() { let mut worker = ref_caller(); diff --git a/src/bee/stock/echo.rs b/src/bee/stock/echo.rs index d1dccaa..f3f9c59 100644 --- a/src/bee/stock/echo.rs +++ b/src/bee/stock/echo.rs @@ -1,18 +1,14 @@ use crate::bee::{Context, Worker, WorkerResult}; -use std::fmt::Debug; +use derive_more::Debug; use std::marker::PhantomData; +use std::{any, fmt}; /// A `Worker` that simply returns the input. -#[derive(Debug)] +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +#[debug("EchoWorker<{}>", any::type_name::())] pub struct EchoWorker(PhantomData); -impl Default for EchoWorker { - fn default() -> Self { - EchoWorker(PhantomData) - } -} - -impl Worker for EchoWorker { +impl Worker for EchoWorker { type Input = T; type Output = T; type Error = (); @@ -24,6 +20,7 @@ impl Worker for EchoWorker { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::*; use crate::bee::Context; diff --git a/src/bee/stock/thunk.rs b/src/bee/stock/thunk.rs index 26b4b15..d8f90c6 100644 --- a/src/bee/stock/thunk.rs +++ b/src/bee/stock/thunk.rs @@ -1,11 +1,13 @@ use crate::bee::{ApplyError, Context, Worker, WorkerResult}; use crate::boxed::BoxedFnOnce; use crate::panic::Panic; -use std::fmt::Debug; +use derive_more::Debug; use std::marker::PhantomData; +use std::{any, fmt}; /// A `Worker` that executes infallible `Thunk`s when applied. #[derive(Debug)] +#[debug("ThunkWorker<{}>", any::type_name::())] pub struct ThunkWorker(PhantomData); impl Default for ThunkWorker { @@ -14,7 +16,13 @@ impl Default for ThunkWorker { } } -impl Worker for ThunkWorker { +impl Clone for ThunkWorker { + fn clone(&self) -> Self { + Self::default() + } +} + +impl Worker for ThunkWorker { type Input = Thunk; type Output = T; type Error = (); @@ -27,6 +35,7 @@ impl Worker for ThunkWorker { /// A `Worker` that executes fallible `Thunk>`s when applied. #[derive(Debug)] +#[debug("FunkWorker<{}, {}>", any::type_name::(), any::type_name::())] pub struct FunkWorker(PhantomData, PhantomData); impl Default for FunkWorker { @@ -35,7 +44,17 @@ impl Default for FunkWorker { } } -impl Worker for FunkWorker { +impl Clone for FunkWorker { + fn clone(&self) -> Self { + Self::default() + } +} + +impl Worker for FunkWorker +where + T: Send + fmt::Debug + 'static, + E: Send + fmt::Debug + 'static, +{ type Input = Thunk>; type Output = T; type Error = E; @@ -50,15 +69,22 @@ impl Worker for FunkWorker /// A `Worker` that executes `Thunk`s that may panic. A panic is caught and returned as an /// `ApplyError::Panic` error. #[derive(Debug)] +#[debug("PunkWorker<{}>", any::type_name::())] pub struct PunkWorker(PhantomData); impl Default for PunkWorker { fn default() -> Self { - PunkWorker(PhantomData) + Self(PhantomData) } } -impl Worker for PunkWorker { +impl Clone for PunkWorker { + fn clone(&self) -> Self { + Self::default() + } +} + +impl Worker for PunkWorker { type Input = Thunk; type Output = T; type Error = (); @@ -72,10 +98,12 @@ impl Worker for PunkWorker { } /// A wrapper around a closure that can be executed exactly once by a worker in a `Hive`. 
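+///
+/// With the `From` impl below, a `Thunk` can be built directly from a closure
+/// (a minimal sketch):
+///
+/// ```ignore
+/// let thunk: Thunk<i32> = Thunk::from(|| 21 * 2);
+/// ```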
+#[derive(Debug)] +#[debug("Thunk<{}>", any::type_name::())] pub struct Thunk(Box + Send>); -impl Thunk { - pub fn of T + Send + 'static>(f: F) -> Self { +impl T + Send + 'static> From for Thunk { + fn from(f: F) -> Self { Self(Box::new(f)) } } @@ -86,13 +114,8 @@ impl Thunk> { } } -impl Debug for Thunk { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("Thunk") - } -} - #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::*; use crate::bee::Context; @@ -100,7 +123,7 @@ mod tests { #[test] fn test_thunk() { let mut worker = ThunkWorker::::default(); - let thunk = Thunk::of(|| 5); + let thunk = Thunk::from(|| 5); assert_eq!(5, worker.apply(thunk, &Context::empty()).unwrap()); } diff --git a/src/bee/worker.rs b/src/bee/worker.rs index adbc93b..6fc3902 100644 --- a/src/bee/worker.rs +++ b/src/bee/worker.rs @@ -94,6 +94,7 @@ where } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::{ApplyRefError, RefWorker, RefWorkerResult, Worker, WorkerResult}; use crate::bee::{ApplyError, Context}; diff --git a/src/channel.rs b/src/channel.rs index 32ae32f..b646767 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -2,12 +2,16 @@ //! //! A maximum one of the channel feature may be enabled. If no channel feature is enabled, then //! `std::sync::mpsc` will be used. +use derive_more::Debug; pub use prelude::channel; pub(crate) use prelude::*; +use std::any; /// Possible results of calling `ReceiverExt::try_recv_msg()` on a `Receiver`. +#[derive(Debug)] pub enum Message { /// A message was successfully received from the channel. + #[debug("Received: {}", any::type_name::())] Received(T), /// The channel was disconnected. ChannelDisconnected, diff --git a/src/hive/builder/bee.rs b/src/hive/builder/bee.rs index fc5d170..d666ed1 100644 --- a/src/hive/builder/bee.rs +++ b/src/hive/builder/bee.rs @@ -1,11 +1,14 @@ use super::{BuilderConfig, FullBuilder, Token}; use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; use crate::hive::{ChannelTaskQueues, Config, TaskQueues, WorkstealingTaskQueues}; +use derive_more::Debug; +use std::any; /// A Builder for creating `Hive` instances for specific [`Worker`] and [`TaskQueues`] types. -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub struct BeeBuilder { config: Config, + #[debug("{}",any::type_name::())] queen: Q, } @@ -28,26 +31,26 @@ impl BeeBuilder { } /// Creates a new `BeeBuilder` from an existing `config` and a `queen`. - pub(super) fn from(config: Config, queen: Q) -> Self { + pub(super) fn from_config_and_queen(config: Config, queen: Q) -> Self { Self { config, queen } } /// Creates a new `FullBuilder` with the current configuration and queen and specified /// `TaskQueues` type. pub fn with_queues>(self) -> FullBuilder { - FullBuilder::from(self.config, self.queen) + FullBuilder::from_config_and_queen(self.config, self.queen) } /// Creates a new `FullBuilder` with the current configuration and queen and channel-based /// task queues. pub fn with_channel_queues(self) -> FullBuilder> { - FullBuilder::from(self.config, self.queen) + FullBuilder::from_config_and_queen(self.config, self.queen) } /// Creates a new `FullBuilder` with the current configuration and queen and workstealing /// task queues. 
pub fn with_workstealing_queues(self) -> FullBuilder> { - FullBuilder::from(self.config, self.queen) + FullBuilder::from_config_and_queen(self.config, self.queen) } } @@ -136,3 +139,157 @@ impl BuilderConfig for BeeBuilder { &mut self.config } } + +impl From for BeeBuilder { + fn from(value: Config) -> Self { + Self::from_config_and_queen(value, Q::default()) + } +} + +impl From for BeeBuilder { + fn from(value: Q) -> Self { + Self::from_config_and_queen(Config::default(), value) + } +} + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use super::*; + use crate::bee::stock::EchoWorker; + use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut}; + use rstest::rstest; + + #[derive(Clone, Default)] + struct TestQueen; + + impl Queen for TestQueen { + type Kind = EchoWorker; + + fn create(&self) -> Self::Kind { + EchoWorker::default() + } + } + + impl QueenMut for TestQueen { + type Kind = EchoWorker; + + fn create(&mut self) -> Self::Kind { + EchoWorker::default() + } + } + + #[rstest] + fn test_queen( + #[values( + BeeBuilder::::empty::, + BeeBuilder::::preset:: + )] + factory: F, + #[values( + BeeBuilder::::with_channel_queues, + BeeBuilder::::with_workstealing_queues, + )] + with_fn: W, + ) where + F: Fn(TestQueen) -> BeeBuilder, + T: TaskQueues>, + W: Fn(BeeBuilder) -> FullBuilder, + { + let bee_builder = factory(TestQueen); + let full_builder = with_fn(bee_builder); + let _hive = full_builder.build(); + } + + #[rstest] + fn test_queen_default( + #[values( + BeeBuilder::::empty_with_queen_default, + BeeBuilder::::preset_with_queen_default + )] + factory: F, + #[values( + BeeBuilder::::with_channel_queues, + BeeBuilder::::with_workstealing_queues, + )] + with_fn: W, + ) where + F: Fn() -> BeeBuilder, + T: TaskQueues>, + W: Fn(BeeBuilder) -> FullBuilder, + { + let bee_builder = factory(); + let full_builder = with_fn(bee_builder); + let _hive = full_builder.build(); + } + + #[rstest] + fn test_queen_mut_default( + #[values( + BeeBuilder::>::empty_with_queen_mut_default, + BeeBuilder::>::preset_with_queen_mut_default + )] + factory: F, + #[values( + BeeBuilder::>::with_channel_queues, + BeeBuilder::>::with_workstealing_queues, + )] + with_fn: W, + ) where + F: Fn() -> BeeBuilder>, + T: TaskQueues>, + W: Fn(BeeBuilder>) -> FullBuilder, T>, + { + let bee_builder = factory(); + let full_builder = with_fn(bee_builder); + let _hive = full_builder.build(); + } + + #[rstest] + fn test_worker( + #[values( + BeeBuilder::>>::empty_with_worker, + BeeBuilder::>>::preset_with_worker + )] + factory: F, + #[values( + BeeBuilder::>>::with_channel_queues, + BeeBuilder::>>::with_workstealing_queues, + )] + with_fn: W, + ) where + F: Fn(EchoWorker) -> BeeBuilder>>, + T: TaskQueues>, + W: Fn( + BeeBuilder>>, + ) -> FullBuilder>, T>, + { + let bee_builder = factory(EchoWorker::default()); + let full_builder = with_fn(bee_builder); + let _hive = full_builder.build(); + } + + #[rstest] + fn test_worker_default( + #[values( + BeeBuilder::>>::empty_with_worker_default, + BeeBuilder::>>::preset_with_worker_default + )] + factory: F, + #[values( + BeeBuilder::>>::with_channel_queues, + BeeBuilder::>>::with_workstealing_queues, + )] + with_fn: W, + ) where + F: Fn() -> BeeBuilder>>, + T: TaskQueues>, + W: Fn( + BeeBuilder>>, + ) -> FullBuilder>, T>, + { + let bee_builder = factory(); + let full_builder = with_fn(bee_builder); + let _hive = full_builder.build(); + } +} diff --git a/src/hive/builder/full.rs b/src/hive/builder/full.rs index 9c2f4f2..52de007 100644 --- 
a/src/hive/builder/full.rs +++ b/src/hive/builder/full.rs @@ -1,13 +1,17 @@ use super::{BuilderConfig, Token}; use crate::bee::Queen; use crate::hive::{Config, Hive, TaskQueues}; +use derive_more::Debug; +use std::any; use std::marker::PhantomData; /// A Builder for creating `Hive` instances for specific [`Queen`] and [`TaskQueues`] types. -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub struct FullBuilder> { config: Config, + #[debug("{}", any::type_name::())] queen: Q, + #[debug("{}", any::type_name::())] _queues: PhantomData, } @@ -32,7 +36,7 @@ impl> FullBuilder { } /// Creates a new `FullBuilder` from an existing `config` and a `queen`. - pub(super) fn from(config: Config, queen: Q) -> Self { + pub(super) fn from_config_and_queen(config: Config, queen: Q) -> Self { Self { config, queen, @@ -46,6 +50,18 @@ impl> FullBuilder { } } +impl> From for FullBuilder { + fn from(value: Config) -> Self { + Self::from_config_and_queen(value, Q::default()) + } +} + +impl> From for FullBuilder { + fn from(value: Q) -> Self { + Self::from_config_and_queen(Config::default(), value) + } +} + impl> BuilderConfig for FullBuilder { fn config_ref(&mut self, _: Token) -> &mut Config { &mut self.config diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs index d5ab32e..e8b2269 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -82,26 +82,22 @@ pub fn workstealing(with_defaults: bool) -> WorkstealingBuilder { } } -// #[cfg(all(test, feature = "affinity"))] -// mod affinity_tests { -// use super::{OpenBuilder, Token}; -// use crate::hive::cores::Cores; +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use super::*; + use crate::hive::Builder; + use rstest::*; -// #[test] -// fn test_with_affinity() { -// let mut builder = OpenBuilder::empty(); -// builder = builder.with_default_core_affinity(); -// assert_eq!(builder.config(Token).affinity.get(), Some(Cores::all())); -// } -// } - -// #[cfg(all(test, feature = "local-batch"))] -// mod local_batch_tests { -// use super::OpenBuilder; -// } - -// #[cfg(all(test, feature = "retry"))] -// mod retry_tests { -// use super::OpenBuilder; -// use std::time::Duration; -// } + #[rstest] + fn test_create B>( + #[values(open, channel, workstealing)] builder_factory: F, + #[values(true, false)] with_defaults: bool, + ) { + let mut builder = builder_factory(with_defaults) + .num_threads(4) + .thread_name("foo") + .thread_stack_size(100); + crate::hive::inner::builder_test_utils::check_builder(&mut builder); + } +} diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs index 8f31f90..14c9196 100644 --- a/src/hive/builder/open.rs +++ b/src/hive/builder/open.rs @@ -37,7 +37,7 @@ use crate::hive::Config; /// .with_channel_queues() /// .build(); /// ``` -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub struct OpenBuilder(Config); impl OpenBuilder { @@ -117,19 +117,19 @@ impl OpenBuilder { /// # } /// ``` pub fn with_queen>(self, queen: I) -> BeeBuilder { - BeeBuilder::from(self.0, queen.into()) + BeeBuilder::from_config_and_queen(self.0, queen.into()) } /// Consumes this `Builder` and returns a new [`BeeBuilder`] using a [`Queen`] created with /// [`Q::default()`](std::default::Default) to create [`Worker`]s. 
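+    ///
+    /// A minimal sketch (`MyQueen` is a hypothetical `Queen + Default` type):
+    ///
+    /// ```ignore
+    /// let builder = OpenBuilder::empty().with_queen_default::<MyQueen>();
+    /// ```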
pub fn with_queen_default(self) -> BeeBuilder { - BeeBuilder::from(self.0, Q::default()) + BeeBuilder::from_config_and_queen(self.0, Q::default()) } /// Consumes this `Builder` and returns a new [`BeeBuilder`] using a [`QueenMut`] created with /// [`Q::default()`](std::default::Default) to create [`Worker`]s. pub fn with_queen_mut_default(self) -> BeeBuilder> { - BeeBuilder::from(self.0, QueenCell::new(Q::default())) + BeeBuilder::from_config_and_queen(self.0, QueenCell::new(Q::default())) } /// Consumes this `Builder` and returns a new [`BeeBuilder`] with [`Worker`]s created by @@ -187,7 +187,7 @@ impl OpenBuilder { where W: Worker + Send + Sync + Clone, { - BeeBuilder::from(self.0, CloneQueen::new(worker)) + BeeBuilder::from_config_and_queen(self.0, CloneQueen::new(worker)) } /// Consumes this `Builder` and returns a new [`BeeBuilder`] with [`Worker`]s created using @@ -239,7 +239,7 @@ impl OpenBuilder { where W: Worker + Send + Sync + Default, { - BeeBuilder::from(self.0, DefaultQueen::default()) + BeeBuilder::from_config_and_queen(self.0, DefaultQueen::default()) } /// Consumes this `Builder` and returns a new [`ChannelBuilder`] using the current diff --git a/src/hive/builder/queue.rs b/src/hive/builder/queue.rs index 902a263..02b0ece 100644 --- a/src/hive/builder/queue.rs +++ b/src/hive/builder/queue.rs @@ -60,7 +60,7 @@ pub mod channel { use crate::hive::{ChannelTaskQueues, Config}; /// `TaskQueuesBuilder` implementation for channel-based task queues. - #[derive(Clone, Default)] + #[derive(Clone, Default, Debug)] pub struct ChannelBuilder(Config); impl BuilderConfig for ChannelBuilder { @@ -83,7 +83,7 @@ pub mod channel { Q: Queen, I: Into, { - FullBuilder::from(self.0, queen.into()) + FullBuilder::from_config_and_queen(self.0, queen.into()) } } @@ -100,7 +100,7 @@ pub mod workstealing { use crate::hive::{Config, WorkstealingTaskQueues}; /// `TaskQueuesBuilder` implementation for workstealing-based task queues. - #[derive(Clone, Default)] + #[derive(Clone, Default, Debug)] pub struct WorkstealingBuilder(Config); impl BuilderConfig for WorkstealingBuilder { @@ -123,7 +123,7 @@ pub mod workstealing { Q: Queen, I: Into, { - FullBuilder::from(self.0, queen.into()) + FullBuilder::from_config_and_queen(self.0, queen.into()) } } diff --git a/src/hive/cores.rs b/src/hive/cores.rs index 5f41253..cca74d1 100644 --- a/src/hive/cores.rs +++ b/src/hive/cores.rs @@ -49,7 +49,7 @@ pub fn refresh() -> usize { } /// Represents a CPU core. -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Core { /// the OS-specific core ID id: CoreId, @@ -81,15 +81,10 @@ impl Core { /// /// The mapping between CPU indices and core IDs is platform-specific, but the same index is /// guaranteed to always refer to the same physical core. -#[derive(Default, Clone, PartialEq, Eq, Debug)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Cores(Vec); impl Cores { - /// Returns an empty `Cores`. - pub fn empty() -> Self { - Self(Vec::new()) - } - /// Returns a `Cores` set populated with the first `n` CPU indices (up to the number of /// available cores). 
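+    ///
+    /// For example, a sketch selecting the first four CPU indices (assuming this
+    /// module is public at `beekeeper::hive::cores`):
+    ///
+    /// ```ignore
+    /// use beekeeper::hive::cores::Cores;
+    ///
+    /// let cores = Cores::first(4);
+    /// ```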
pub fn first(n: usize) -> Self { @@ -190,12 +185,13 @@ impl> From for Cores { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::*; #[test] fn test_empty() { - assert_eq!(Cores::empty().0.len(), 0); + assert_eq!(Cores::default().0.len(), 0); } #[test] diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 56aade4..b28e129 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -4,9 +4,9 @@ use super::{ TaskQueues, TaskQueuesBuilder, }; use crate::bee::{Context, DefaultQueen, Queen, TaskId, Worker}; +use derive_more::Debug; use std::borrow::Borrow; use std::collections::HashMap; -use std::fmt; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::thread::JoinHandle; @@ -18,6 +18,7 @@ pub struct Poisoned; /// A pool of worker threads that each execute the same function. /// /// See the [module documentation](crate::hive) for details. +#[derive(Debug)] pub struct Hive>(Option>>); impl> Hive { @@ -683,21 +684,6 @@ where } } -impl fmt::Debug for Hive -where - W: Worker, - Q: Queen, - T: TaskQueues, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(shared) = self.0.as_ref() { - f.debug_struct("Hive").field("shared", &shared).finish() - } else { - f.write_str("Hive {}") - } - } -} - impl PartialEq for Hive where W: Worker, @@ -872,6 +858,7 @@ mod retry { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::Poisoned; use crate::bee::stock::{Caller, Thunk, ThunkWorker}; @@ -891,7 +878,7 @@ mod tests { .build(); let (tx, rx) = outcome_channel(); hive.map_send( - (0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3)))), + (0..10).map(|_| Thunk::from(|| thread::sleep(Duration::from_secs(3)))), tx, ); // Allow first set of tasks to be started. @@ -931,7 +918,7 @@ mod tests { fn test_apply_after_poison() { let hive = ChannelBuilder::empty() .num_threads(4) - .with_worker(Caller::of(|i: usize| i * 2)) + .with_worker(Caller::from(|i: usize| i * 2)) .build(); // poison hive using private method hive.0.as_ref().unwrap().poison(); @@ -953,7 +940,7 @@ mod tests { fn test_swarm_after_poison() { let hive = ChannelBuilder::empty() .num_threads(4) - .with_worker(Caller::of(|i: usize| i * 2)) + .with_worker(Caller::from(|i: usize| i * 2)) .build(); // poison hive using private method hive.0.as_ref().unwrap().poison(); diff --git a/src/hive/husk.rs b/src/hive/husk.rs index cda6d77..6bccb3b 100644 --- a/src/hive/husk.rs +++ b/src/hive/husk.rs @@ -3,6 +3,8 @@ use super::{ OwnedOutcomes, TaskQueues, }; use crate::bee::{Queen, TaskId, Worker}; +use derive_more::Debug; +use std::any; use std::collections::HashMap; use std::ops::{Deref, DerefMut}; @@ -10,10 +12,13 @@ use std::ops::{Deref, DerefMut}; /// /// Provides access to the `Queen` and to stored `Outcome`s. Can be used to create a new `Hive` /// based on the previous `Hive`'s configuration. 
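+///
+/// A sketch of the typical round trip (`hive` is an existing `Hive`; passing
+/// `urgent = false` lets queued tasks finish before the hive is dismantled):
+///
+/// ```ignore
+/// let husk = hive.try_into_husk(false).unwrap();
+/// let (_, outputs) = husk.into_parts();
+/// ```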
+#[derive(Debug)] pub struct Husk { config: Config, + #[debug("{}", any::type_name::())] queen: Q, num_panics: usize, + #[debug(skip)] outcomes: HashMap>, } @@ -135,6 +140,7 @@ impl> OwnedOutcomes for Husk { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use crate::bee::stock::{PunkWorker, Thunk, ThunkWorker}; use crate::hive::ChannelTaskQueues; @@ -150,7 +156,7 @@ mod tests { .num_threads(0) .with_worker_default::>() .build(); - let mut task_ids = hive.map_store((0..10).map(|i| Thunk::of(move || i))); + let mut task_ids = hive.map_store((0..10).map(|i| Thunk::from(move || i))); // cancel and smash the hive before the tasks can be processed hive.suspend(); let mut husk = hive.try_into_husk(false).unwrap(); @@ -176,7 +182,7 @@ mod tests { .num_threads(0) .with_worker_default::>() .build(); - let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i))); + let _ = hive1.map_store((0..10).map(|i| Thunk::from(move || i))); // cancel and smash the hive before the tasks can be processed hive1.suspend(); let husk1 = hive1.try_into_husk(false).unwrap(); @@ -197,7 +203,7 @@ mod tests { .num_threads(0) .with_worker_default::>() .build(); - let _ = hive1.map_store((0..10).map(|i| Thunk::of(move || i))); + let _ = hive1.map_store((0..10).map(|i| Thunk::from(move || i))); // cancel and smash the hive before the tasks can be processed hive1.suspend(); let husk1 = hive1.try_into_husk(false).unwrap(); @@ -222,7 +228,7 @@ mod tests { .num_threads(4) .with_worker_default::>() .build(); - hive.map_store((0..10).map(|i| Thunk::of(move || i))); + hive.map_store((0..10).map(|i| Thunk::from(move || i))); hive.join(); let mut outputs = hive.try_into_husk(false).unwrap().into_parts().1.unwrap(); outputs.sort(); @@ -237,7 +243,7 @@ mod tests { .with_worker_default::>() .build(); hive.map_store( - (0..10).map(|i| Thunk::of(move || if i == 5 { panic!("oh no!") } else { i })), + (0..10).map(|i| Thunk::from(move || if i == 5 { panic!("oh no!") } else { i })), ); hive.join(); let (_, result) = hive.try_into_husk(false).unwrap().into_parts(); diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 2b976fe..67703f2 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -198,8 +198,8 @@ pub trait Builder: BuilderConfig + Sized { /// Sets the worker thread batch size. /// - /// This may have no effect if the `local-batch` feature is disabled, or if the `TaskQueues` - /// implementation used for this hive does not support local batching. + /// This may have no effect if the `TaskQueues` implementation used for this hive does not + /// support local batching. /// /// If `batch_limit` is `0`, local batching is effectively disabled, but note that the /// performance may be worse than with the `local-batch` feature disabled. @@ -253,6 +253,12 @@ pub trait Builder: BuilderConfig + Sized { self } + /// Disables local batching. + #[cfg(feature = "local-batch")] + fn with_no_local_batching(self) -> Self { + self.batch_limit(0).weight_limit(0) + } + /// Sets the maximum number of times to retry a /// [`ApplyError::Retryable`](crate::bee::ApplyError::Retryable) error. A worker /// thread will retry a task until it either returns @@ -353,16 +359,21 @@ pub trait Builder: BuilderConfig + Sized { /// Sets retry parameters to their default values. 
#[cfg(feature = "retry")] - fn with_default_retries(mut self) -> Self { - let defaults = super::config::DEFAULTS.lock(); + fn with_default_max_retries(mut self) -> Self { let _ = self .config_ref(Token) .max_retries - .set(defaults.max_retries.get()); + .set(super::config::DEFAULTS.lock().max_retries.get()); + + self + } + + #[cfg(feature = "retry")] + fn with_default_retry_factor(mut self) -> Self { let _ = self .config_ref(Token) .retry_factor - .set(defaults.retry_factor.get()); + .set(super::config::DEFAULTS.lock().retry_factor.get()); self } @@ -374,3 +385,136 @@ pub trait Builder: BuilderConfig + Sized { } impl Builder for B {} + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use super::*; + pub struct TestBuilder(Config); + + impl TestBuilder { + pub fn empty() -> Self { + TestBuilder(Config::empty()) + } + } + + impl BuilderConfig for TestBuilder { + fn config_ref(&mut self, _: Token) -> &mut Config { + &mut self.0 + } + } + + #[test] + fn test_common() { + let mut builder = TestBuilder::empty() + .num_threads(4) + .thread_name("foo") + .thread_stack_size(100); + crate::hive::inner::builder_test_utils::check_builder(&mut builder); + } +} + +#[cfg(all(test, feature = "affinity"))] +#[cfg_attr(coverage_nightly, coverage(off))] +mod affinity_tests { + use super::tests::TestBuilder; + use super::*; + use crate::hive::cores::Cores; + + #[test] + fn test_core_affinity() { + let mut builder = TestBuilder::empty(); + builder = builder.core_affinity(Cores::first(4)); + assert_eq!( + builder.config_ref(Token).affinity.get(), + Some((0..4).into()) + ); + } + + #[test] + fn test_with_default_core_affinity() { + let mut builder = TestBuilder::empty(); + builder = builder.with_default_core_affinity(); + assert_eq!(builder.config_ref(Token).affinity.get(), Some(Cores::all())); + } +} + +#[cfg(all(test, feature = "local-batch"))] +#[cfg_attr(coverage_nightly, coverage(off))] +mod local_batch_tests { + use super::tests::TestBuilder; + use super::*; + use crate::hive::inner::config::DEFAULTS; + + #[test] + fn test_batch_config() { + let mut builder = TestBuilder::empty().batch_limit(10).weight_limit(100); + let config = builder.config_ref(Token); + assert_eq!(config.batch_limit.get(), Some(10)); + assert_eq!(config.weight_limit.get(), Some(100)); + } + + #[test] + fn test_disable_batch_config() { + let mut builder = TestBuilder::empty().with_no_local_batching(); + let config = builder.config_ref(Token); + assert_eq!(config.batch_limit.get(), None); + assert_eq!(config.weight_limit.get(), None); + } + + #[test] + fn test_default_batch_config() { + let mut builder = TestBuilder::empty() + .with_default_batch_limit() + .with_default_weight_limit(); + let config = builder.config_ref(Token); + assert_eq!(config.batch_limit.get(), DEFAULTS.lock().batch_limit.get()); + assert_eq!( + config.weight_limit.get(), + DEFAULTS.lock().weight_limit.get() + ); + } +} + +#[cfg(all(test, feature = "retry"))] +#[cfg_attr(coverage_nightly, coverage(off))] +mod retry_tests { + use super::tests::TestBuilder; + use super::*; + use crate::hive::inner::config::DEFAULTS; + use std::time::Duration; + + #[test] + fn test_retry_config() { + let mut builder = TestBuilder::empty() + .max_retries(5) + .retry_factor(Duration::from_secs(10)); + let config = builder.config_ref(Token); + assert_eq!(config.max_retries.get(), Some(5)); + assert_eq!( + config.retry_factor.get(), + Some(Duration::from_secs(10).as_nanos() as u64) + ); + } + + #[test] + fn test_disable_retry() { + let mut builder = 
TestBuilder::empty().with_no_retries(); + let config = builder.config_ref(Token); + assert_eq!(config.max_retries.get(), None); + assert_eq!(config.retry_factor.get(), None); + } + + #[test] + fn test_default_retry_config() { + let mut builder = TestBuilder::empty() + .with_default_max_retries() + .with_default_retry_factor(); + let config = builder.config_ref(Token); + assert_eq!(config.max_retries.get(), DEFAULTS.lock().max_retries.get()); + assert_eq!( + config.retry_factor.get(), + DEFAULTS.lock().retry_factor.get() + ); + } +} diff --git a/src/hive/inner/config.rs b/src/hive/inner/config.rs index 5bc749e..358076c 100644 --- a/src/hive/inner/config.rs +++ b/src/hive/inner/config.rs @@ -113,6 +113,7 @@ impl Default for Config { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] pub mod reset { /// Struct that resets the default values when `drop`ped. pub struct Reset; @@ -125,6 +126,7 @@ pub mod reset { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::Config; use super::reset::Reset; @@ -225,6 +227,7 @@ mod retry { } #[cfg(test)] + #[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::Config; use crate::hive::inner::config::reset::Reset; diff --git a/src/hive/inner/counter.rs b/src/hive/inner/counter.rs index 93072da..a81147a 100644 --- a/src/hive/inner/counter.rs +++ b/src/hive/inner/counter.rs @@ -138,6 +138,7 @@ impl Default for DualCounter { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::*; diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs index dbfaf3a..504cf8d 100644 --- a/src/hive/inner/mod.rs +++ b/src/hive/inner/mod.rs @@ -113,3 +113,15 @@ pub struct Config { #[cfg(feature = "retry")] retry_factor: U64, } + +#[cfg(test)] +pub(super) mod builder_test_utils { + use super::*; + + pub fn check_builder(builder: &mut B) { + let config = builder.config_ref(Token); + assert_eq!(config.num_threads.get(), Some(4)); + assert_eq!(config.thread_name.get(), Some("foo".into())); + assert_eq!(config.thread_stack_size.get(), Some(100)); + } +} diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index ae90b49..d72a853 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -5,7 +5,9 @@ use super::{Config, PopTaskError, Status, Task, TaskQueues, Token, WorkerQueues} use crate::bee::Worker; use crossbeam_channel::RecvTimeoutError; use crossbeam_queue::SegQueue; +use derive_more::Debug; use parking_lot::RwLock; +use std::any; use std::sync::Arc; use std::time::Duration; @@ -21,6 +23,8 @@ type TaskReceiver = crossbeam_channel::Receiver>; /// /// Worker threads may have access to local retry and/or batch queues, depending on which features /// are enabled. 
+#[derive(Debug)] +#[debug("ChannelTaskQueues<{}>", any::type_name::())] pub struct ChannelTaskQueues { global: Arc>, local: RwLock>>>, diff --git a/src/hive/inner/queue/retry.rs b/src/hive/inner/queue/retry.rs index 7469724..bdc01fb 100644 --- a/src/hive/inner/queue/retry.rs +++ b/src/hive/inner/queue/retry.rs @@ -140,6 +140,7 @@ impl PartialEq for DelayedTask { impl Eq for DelayedTask {} #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::{RetryQueue, Task, Worker}; use crate::bee::stock::EchoWorker; diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 4596ddc..4557780 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -10,17 +10,20 @@ use crate::atomic::Atomic; use crate::bee::Worker; use crossbeam_deque::{Injector, Stealer}; use crossbeam_queue::SegQueue; +use derive_more::Debug; use parking_lot::RwLock; use rand::prelude::*; use std::ops::Deref; use std::sync::Arc; -use std::thread; use std::time::Duration; +use std::{any, thread}; /// Time to wait after trying to pop and finding all queues empty. const EMPTY_DELAY: Duration = Duration::from_millis(100); /// `TaskQueues` implementation using workstealing. +#[derive(Debug)] +#[debug("WorkstealingTaskQueues<{}>", any::type_name::())] pub struct WorkstealingTaskQueues { global: Arc>, local: RwLock>>>, diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index 5d63175..cf192d7 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -736,13 +736,14 @@ mod retry { self.num_tasks .increment_left(1) .expect("overflowed queued task counter"); - let task = Task::with_meta_inc_attempt(input, meta, outcome_tx.cloned()); + let task = Task::next_retry_attempt(input, meta, outcome_tx.cloned()); worker_queues.try_push_retry(task) } } } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use crate::bee::DefaultQueen; use crate::bee::stock::ThunkWorker; diff --git a/src/hive/inner/task.rs b/src/hive/inner/task.rs index 4e547b4..68219bc 100644 --- a/src/hive/inner/task.rs +++ b/src/hive/inner/task.rs @@ -7,34 +7,6 @@ use crate::hive::{Outcome, OutcomeSender}; pub use task_impl::TaskInput; impl Task { - /// Creates a new `Task` with the given metadata. - pub fn with_meta( - input: W::Input, - meta: TaskMeta, - outcome_tx: Option>, - ) -> Self { - Task { - input, - meta, - outcome_tx, - } - } - - /// Creates a new `Task` with the given metadata, and increments the attempt number. - #[cfg(feature = "retry")] - pub fn with_meta_inc_attempt( - input: W::Input, - mut meta: TaskMeta, - outcome_tx: Option>, - ) -> Self { - meta.inc_attempt(); - Self { - input, - meta, - outcome_tx, - } - } - /// Returns the ID of the task. #[inline] pub fn id(&self) -> TaskId { @@ -61,6 +33,21 @@ impl Task { }; (outcome, self.outcome_tx) } + + /// Creates a new `Task` with the given metadata, and increments the attempt number. + #[cfg(feature = "retry")] + pub fn next_retry_attempt( + input: W::Input, + mut meta: TaskMeta, + outcome_tx: Option>, + ) -> Self { + meta.inc_attempt(); + Self { + input, + meta, + outcome_tx, + } + } } #[cfg(not(feature = "local-batch"))] diff --git a/src/hive/mock.rs b/src/hive/mock.rs index e79ea7b..6a1e5fd 100644 --- a/src/hive/mock.rs +++ b/src/hive/mock.rs @@ -5,12 +5,18 @@ use std::cell::RefCell; /// A struct used for testing `Worker`s in a mock environment without needing to create a `Hive`. 
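+///
+/// A minimal sketch using a stock worker (the runner starts task IDs at 0):
+///
+/// ```ignore
+/// use beekeeper::bee::stock::EchoWorker;
+///
+/// let runner = MockTaskRunner::<EchoWorker<usize>>::default();
+/// let outcome = runner.apply(42);
+/// ```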
#[derive(Debug)] -pub struct MockTaskRunner(RefCell); +pub struct MockTaskRunner { + worker: RefCell, + task_id: RefCell, +} -impl MockTaskRunner { +impl MockTaskRunner { /// Creates a new `MockTaskRunner` with a starting task ID of 0. - pub fn new() -> Self { - Self(RefCell::new(0)) + pub fn new(worker: W, first_task_id: TaskId) -> Self { + Self { + worker: RefCell::new(worker), + task_id: RefCell::new(first_task_id), + } } /// Applies the given `worker` to the given `input`. @@ -18,29 +24,45 @@ impl MockTaskRunner { /// The task ID is automatically incremented and used to create the `Context`. /// /// Returns the `Outcome` from executing the task. - pub fn apply(&self, worker: &mut W, input: TaskInput) -> Outcome { + pub fn apply>>(&self, input: I) -> Outcome { let task_id = self.next_task_id(); - let local = MockLocalContext(&self); - let task: Task = Task::new(task_id, input, None); + let local = MockLocalContext(self); + let task: Task = Task::new(task_id, input.into(), None); let (input, task_meta, _) = task.into_parts(); let ctx = Context::new(task_meta, Some(&local)); - let result = worker.apply(input, &ctx); + let result = self.worker.borrow_mut().apply(input, &ctx); let (task_meta, subtask_ids) = ctx.into_parts(); Outcome::from_worker_result(result, task_meta, subtask_ids) } fn next_task_id(&self) -> TaskId { - let mut task_id_counter = self.0.borrow_mut(); + let mut task_id_counter = self.task_id.borrow_mut(); let task_id = *task_id_counter; *task_id_counter += 1; task_id } } +impl From for MockTaskRunner { + fn from(value: W) -> Self { + Self::new(value, 0) + } +} + +impl Default for MockTaskRunner { + fn default() -> Self { + Self::from(W::default()) + } +} + #[derive(Debug)] -struct MockLocalContext<'a>(&'a MockTaskRunner); +struct MockLocalContext<'a, W: Worker>(&'a MockTaskRunner); -impl<'a, I> LocalContext for MockLocalContext<'a> { +impl LocalContext for MockLocalContext<'_, W> +where + W: Worker, + I: Into>, +{ fn should_cancel_tasks(&self) -> bool { false } @@ -49,3 +71,46 @@ impl<'a, I> LocalContext for MockLocalContext<'a> { self.0.next_task_id() } } + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use std::vec; + + use super::MockTaskRunner; + use crate::bee::{Context, Worker, WorkerResult}; + use crate::hive::Outcome; + + #[derive(Debug, Default)] + struct TestWorker; + + impl Worker for TestWorker { + type Input = usize; + type Output = usize; + type Error = (); + + fn apply(&mut self, input: Self::Input, ctx: &Context) -> WorkerResult { + if !ctx.is_cancelled() { + for i in 1..=3 { + ctx.submit(input + i).unwrap(); + } + } + Ok(input) + } + } + + #[test] + fn test_works() { + let runner = MockTaskRunner::::default(); + let outcome = runner.apply(42); + assert!(matches!( + outcome, + Outcome::SuccessWithSubtasks { + value: 42, + task_id: 0, + .. 
+ } + )); + assert_eq!(outcome.subtask_ids(), Some(&vec![1, 2, 3])) + } +} diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 7883373..c9cd590 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -440,6 +440,7 @@ pub mod prelude { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::inner::TaskQueues; use super::{ @@ -508,7 +509,7 @@ mod tests { assert!(!hive.has_dead_workers()); for _ in 0..TEST_TASKS { let tx = tx.clone(); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { tx.send(1).unwrap(); })); } @@ -525,7 +526,7 @@ mod tests { let hive = thunk_hive::(0, builder_factory(true)); // check that with 0 threads no tasks are scheduled let (tx, rx) = super::outcome_channel(); - let _ = hive.apply_send(Thunk::of(|| 0), &tx); + let _ = hive.apply_send(Thunk::from(|| 0), &tx); thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().0, 1); assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); @@ -548,7 +549,7 @@ mod tests { let hive = void_thunk_hive(TEST_TASKS, builder_factory(false)); // queue some long-running tasks for _ in 0..TEST_TASKS { - hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK))); } thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); @@ -558,7 +559,7 @@ mod tests { hive.grow(new_threads).expect("error spawning threads"); // queue some more long-running tasks for _ in 0..new_threads { - hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK))); } thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().1, total_threads as u64); @@ -576,7 +577,7 @@ mod tests { // queue some long-running tasks let total_tasks = 2 * TEST_TASKS; for _ in 0..total_tasks { - hive.apply_store(Thunk::of(|| thread::sleep(SHORT_TASK))); + hive.apply_store(Thunk::from(|| thread::sleep(SHORT_TASK))); } thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks(), (TEST_TASKS as u64, TEST_TASKS as u64)); @@ -650,7 +651,7 @@ mod tests { { let hive = void_thunk_hive(TEST_TASKS, builder_factory(false)); for _ in 0..2 * TEST_TASKS { - hive.apply_store(Thunk::of(|| { + hive.apply_store(Thunk::from(|| { loop { thread::sleep(LONG_TASK) } @@ -674,7 +675,7 @@ mod tests { .build(); let num_threads = num_cpus::get(); for _ in 0..num_threads { - hive.apply_store(Thunk::of(|| { + hive.apply_store(Thunk::from(|| { loop { thread::sleep(LONG_TASK) } @@ -696,7 +697,7 @@ mod tests { let (tx, _) = super::outcome_channel(); // Panic all the existing threads. 
for _ in 0..TEST_TASKS { - hive.apply_send(Thunk::of(|| panic!("intentional panic")), &tx); + hive.apply_send(Thunk::from(|| panic!("intentional panic")), &tx); } hive.join(); // Ensure that none of the threads have panicked @@ -712,7 +713,7 @@ mod tests { F: Fn(bool) -> B, { let hive: Hive<_, _> = builder_factory(false) - .with_worker(RefCaller::of(|_: &u8| -> Result { + .with_worker(RefCaller::from(|_: &u8| -> Result { panic!("intentional panic") })) .num_threads(TEST_TASKS) @@ -746,7 +747,7 @@ mod tests { for _ in 0..TEST_TASKS { let waiter = waiter.clone(); let waiter_count = waiter_count.clone(); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { waiter_count.fetch_add(1, Ordering::SeqCst); waiter.wait(); panic!("intentional panic"); @@ -784,7 +785,7 @@ mod tests { let tx = tx.clone(); let (b0, b1) = (b0.clone(), b1.clone()); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { // Wait until the pool has been filled once. b0.wait(); // wait so the pool can be measured @@ -825,7 +826,7 @@ mod tests { // initial thread should share the name "test" for _ in 0..2 { let tx = tx.clone(); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { let name = thread::current().name().unwrap().to_owned(); tx.send(name).unwrap(); })); @@ -834,7 +835,7 @@ mod tests { // new spawn thread should share the name "test" too. hive.grow(3).expect("error spawning threads"); let tx_clone = tx.clone(); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { let name = thread::current().name().unwrap().to_owned(); tx_clone.send(name).unwrap(); })); @@ -859,7 +860,7 @@ mod tests { .build(); let actual_stack_size = hive - .apply(Thunk::of(|| { + .apply(Thunk::from(|| { //println!("This thread has a 4 MB stack size!"); stacker::remaining_stack().unwrap() })) @@ -880,7 +881,7 @@ mod tests { let debug = format!("{:?}", hive); assert_eq!( debug, - "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + "Hive(Some(Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 }))" ); let hive: Hive>, B::TaskQueues<_>> = builder_factory(false) @@ -891,16 +892,16 @@ mod tests { let debug = format!("{:?}", hive); assert_eq!( debug, - "Hive { shared: Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 } }" + "Hive(Some(Shared { name: \"hello\", num_threads: 4, num_tasks_queued: 0, num_tasks_active: 0 }))" ); let hive = thunk_hive(4, builder_factory(true)); - hive.apply_store(Thunk::of(|| thread::sleep(LONG_TASK))); + hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK))); thread::sleep(ONE_SEC); let debug = format!("{:?}", hive); assert_eq!( debug, - "Hive { shared: Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 } }" + "Hive(Some(Shared { name: None, num_threads: 4, num_tasks_queued: 0, num_tasks_active: 1 }))" ); } @@ -920,7 +921,7 @@ mod tests { for _ in 0..42 { let test_count = test_count.clone(); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { thread::sleep(SHORT_TASK); test_count.fetch_add(1, Ordering::Release); })); @@ -931,7 +932,7 @@ mod tests { for _ in 0..42 { let test_count = test_count.clone(); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { thread::sleep(SHORT_TASK); test_count.fetch_add(1, Ordering::Relaxed); })); @@ -972,8 +973,8 @@ mod tests { let hive1_clone = hive1.clone(); let hive0_clone = hive0.clone(); let tx = 
tx.clone(); - hive0.apply_store(Thunk::of(move || { - hive1_clone.apply_store(Thunk::of(move || { + hive0.apply_store(Thunk::from(move || { + hive1_clone.apply_store(Thunk::from(move || { //error(format!("p1: {} -=- {:?}\n", i, hive0_clone)); hive0_clone.join(); // ensure that the main thread has a chance to execute @@ -1027,13 +1028,13 @@ mod tests { .num_threads(8) .build(); - hive.apply_store(Thunk::of(sleepy_function)); + hive.apply_store(Thunk::from(sleepy_function)); let p_t = hive.clone(); thread::spawn(move || { (0..23) .inspect(|_| { - p_t.apply_store(Thunk::of(sleepy_function)); + p_t.apply_store(Thunk::from(sleepy_function)); }) .count(); }); @@ -1050,7 +1051,7 @@ mod tests { let hive = thunk_hive::(2, builder_factory(false)); let outputs: Vec<_> = hive .map((0..10u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((10 - i as u64) * 100)); i }) @@ -1069,7 +1070,7 @@ mod tests { let hive = thunk_hive::(8, builder_factory(false)); let mut outputs: Vec<_> = hive .map_unordered((0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); i }) @@ -1090,7 +1091,7 @@ mod tests { let (tx, rx) = super::outcome_channel(); let mut task_ids = hive.map_send( (0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); i }) @@ -1119,7 +1120,7 @@ mod tests { { let mut hive = thunk_hive::(8, builder_factory(false)); let mut task_ids = hive.map_store((0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); i }) @@ -1148,7 +1149,7 @@ mod tests { let hive = thunk_hive::(2, builder_factory(false)); let outputs: Vec<_> = hive .swarm((0..10u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((10 - i as u64) * 100)); i }) @@ -1168,7 +1169,7 @@ mod tests { let hive = thunk_hive::(8, builder_factory(false)); let mut outputs: Vec<_> = hive .swarm_unordered((0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); i }) @@ -1191,7 +1192,7 @@ mod tests { let (tx, rx) = super::outcome_channel(); let mut task_ids = hive.swarm_send( (0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 200)); i }) @@ -1219,7 +1220,7 @@ mod tests { { let mut hive = thunk_hive::(8, builder_factory(false)); let mut task_ids = hive.swarm_store((0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); i }) @@ -1246,7 +1247,7 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(Caller::of(|i: usize| i * i)) + .with_worker(Caller::from(|i: usize| i * i)) .num_threads(4) .build(); let (outputs, state) = hive.scan(0..10usize, 0, |acc, i| { @@ -1276,7 +1277,7 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(Caller::of(|i: i32| i * i)) + .with_worker(Caller::from(|i: i32| i * i)) .num_threads(4) .build(); let (tx, rx) = super::outcome_channel(); @@ -1316,7 +1317,7 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(Caller::of(|i: i32| i * i)) + .with_worker(Caller::from(|i: i32| i * i)) .num_threads(4) .build(); let (tx, rx) = super::outcome_channel(); @@ -1359,7 +1360,7 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - 
.with_worker(OnceCaller::of(|i: i32| Ok::<_, String>(i * i))) + .with_worker(OnceCaller::from(|i: i32| Ok::<_, String>(i * i))) .num_threads(4) .build(); let (tx, _) = super::outcome_channel(); @@ -1378,7 +1379,7 @@ mod tests { F: Fn(bool) -> B, { let mut hive = builder_factory(false) - .with_worker(Caller::of(|i: i32| i * i)) + .with_worker(Caller::from(|i: i32| i * i)) .num_threads(4) .build(); let (mut task_ids, state) = hive.scan_store(0..10, 0, |acc, i| { @@ -1419,7 +1420,7 @@ mod tests { F: Fn(bool) -> B, { let mut hive = builder_factory(false) - .with_worker(Caller::of(|i: i32| i * i)) + .with_worker(Caller::from(|i: i32| i * i)) .num_threads(4) .build(); let (results, state) = hive.try_scan_store(0..10, 0, |acc, i| { @@ -1462,7 +1463,7 @@ mod tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(OnceCaller::of(|i: i32| Ok::(i * i))) + .with_worker(OnceCaller::from(|i: i32| Ok::(i * i))) .num_threads(4) .build(); let _ = hive @@ -1520,7 +1521,7 @@ mod tests { F: Fn(bool) -> B, { let hive1 = thunk_hive::(8, builder_factory(false)); - let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::of(move || i))); + let task_ids = hive1.map_store((0..8u8).map(|i| Thunk::from(move || i))); hive1.join(); let mut husk1 = hive1.try_into_husk(false).unwrap(); for i in task_ids.iter() { @@ -1535,7 +1536,7 @@ mod tests { .with_channel_queues() .build(); hive2.map_store((0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); i }) @@ -1559,7 +1560,7 @@ mod tests { let hive3 = husk1.into_hive::>(); hive3.map_store((0..8u8).map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((8 - i as u64) * 100)); i }) @@ -1589,7 +1590,7 @@ mod tests { // This batch of tasks will occupy the pool for some time for _ in 0..6 { - hive.apply_store(Thunk::of(|| { + hive.apply_store(Thunk::from(|| { thread::sleep(SHORT_TASK); })); } @@ -1604,7 +1605,7 @@ mod tests { let (tx, rx) = mpsc::channel(); for i in 0..42 { let tx = tx.clone(); - hive.apply_store(Thunk::of(move || { + hive.apply_store(Thunk::from(move || { tx.send(i).expect("channel will be waiting"); })); } @@ -1621,7 +1622,7 @@ mod tests { let (tx, rx) = mpsc::channel(); for i in 1..12 { let tx = tx.clone(); - pool.apply_store(Thunk::of(move || { + pool.apply_store(Thunk::from(move || { tx.send(i).expect("channel will be waiting"); })); } @@ -1723,7 +1724,7 @@ mod tests { { let barrier = barrier.clone(); - clock_hive.apply_store(Thunk::of(move || { + clock_hive.apply_store(Thunk::from(move || { barrier.wait(); // this sleep is for stabilisation on weaker platforms thread::sleep(Duration::from_millis(100)); @@ -1735,11 +1736,11 @@ mod tests { let tx = tx.clone(); let clock_hive = clock_hive.clone(); let wave_counter = wave_counter.clone(); - waiter_hive.apply_store(Thunk::of(move || { + waiter_hive.apply_store(Thunk::from(move || { let wave_before = wave_counter.load(Ordering::SeqCst); clock_hive.join(); // submit tasks for the next wave - clock_hive.apply_store(Thunk::of(|| thread::sleep(ONE_SEC))); + clock_hive.apply_store(Thunk::from(|| thread::sleep(ONE_SEC))); let wave_after = wave_counter.load(Ordering::SeqCst); tx.send((wave_before, wave_after, worker)).unwrap(); })); @@ -1798,13 +1799,13 @@ mod tests { // return results to your own channel... 
let (tx, rx) = crate::hive::outcome_channel(); - let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::of(move || i * i)), &tx); + let task_ids = hive.swarm_send((0..10).map(|i: i32| Thunk::from(move || i * i)), &tx); let outputs: Vec<_> = rx.select_unordered_outputs(task_ids).collect(); assert_eq!(285, outputs.into_iter().sum()); // return results as an iterator... let outputs2: Vec<_> = hive - .swarm((0..10).map(|i: i32| Thunk::of(move || i * -i))) + .swarm((0..10).map(|i: i32| Thunk::from(move || i * -i))) .into_outputs() .collect(); assert_eq!(-285, outputs2.into_iter().sum()); @@ -1970,7 +1971,7 @@ mod affinity_tests { .build(); hive.map_store((0..10).map(move |i| { - Thunk::of(move || { + Thunk::from(move || { if let Some(affininty) = core_affinity::get_core_ids() { eprintln!("task {} on thread with affinity {:?}", i, affininty); } @@ -1988,7 +1989,7 @@ mod affinity_tests { .build(); hive.map_store((0..num_cpus::get()).map(move |i| { - Thunk::of(move || { + Thunk::from(move || { if let Some(affininty) = core_affinity::get_core_ids() { eprintln!("task {} on thread with affinity {:?}", i, affininty); } @@ -2003,8 +2004,8 @@ mod local_batch_tests { use crate::bee::DefaultQueen; use crate::bee::stock::{Thunk, ThunkWorker}; use crate::hive::{ - Builder, Hive, Outcome, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender, TaskQueues, - TaskQueuesBuilder, WeightedExactSizeIteratorExt, channel_builder, workstealing_builder, + Builder, Hive, OutcomeIteratorExt, OutcomeReceiver, OutcomeSender, TaskQueues, + TaskQueuesBuilder, channel_builder, workstealing_builder, }; use rstest::*; use std::collections::HashMap; @@ -2024,7 +2025,7 @@ mod local_batch_tests { .map(|_| { let barrier = barrier.clone(); let task_id = hive.apply_send( - Thunk::of(move || { + Thunk::from(move || { barrier.wait(); thread::sleep(Duration::from_millis(100)); thread::current().id() @@ -2038,7 +2039,7 @@ mod local_batch_tests { // send the rest all at once let rest_task_ids = hive.map_send( (num_threads..total_tasks).map(|_| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis(1)); thread::current().id() }) @@ -2183,6 +2184,98 @@ mod local_batch_tests { assert!(thread_counts.values().all(|count| *count > BATCH_LIMIT_0)); assert_eq!(thread_counts.values().sum::(), total_tasks); } +} + +#[cfg(all(test, feature = "local-batch"))] +mod weighted_map_tests { + use crate::bee::stock::{Thunk, ThunkWorker}; + use crate::hive::{ + Builder, Outcome, TaskQueuesBuilder, Weighted, WeightedIteratorExt, channel_builder, + workstealing_builder, + }; + use rstest::*; + use std::thread; + use std::time::Duration; + + #[rstest] + fn test_map_weighted(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + const NUM_THREADS: usize = 4; + const BATCH_LIMIT: usize = 24; + let hive = builder_factory(false) + .with_worker_default::>() + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT) + .build(); + let inputs = (0..10u8) + .map(|i| { + Thunk::from(move || { + thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + i + }) + }) + .map(|thunk| (thunk, 0)) + .into_weighted(); + let outputs: Vec<_> = hive.map(inputs).map(Outcome::unwrap).collect(); + assert_eq!(outputs, (0..10).collect::>()) + } + + #[rstest] + fn test_overweight() { + const WEIGHT_LIMIT: u64 = 99; + let hive = channel_builder(false) + .with_worker_default::>() + .num_threads(1) + .weight_limit(WEIGHT_LIMIT) + .build(); + let outcome = 
hive.apply(Weighted::new(Thunk::from(|| 0), 100)); + assert!(matches!( + outcome, + Outcome::WeightLimitExceeded { weight: 100, .. } + )) + } +} + +#[cfg(all(test, feature = "local-batch"))] +mod weighted_swarm_tests { + use crate::bee::stock::{EchoWorker, Thunk, ThunkWorker}; + use crate::hive::{ + Builder, Outcome, TaskQueuesBuilder, WeightedExactSizeIteratorExt, channel_builder, + workstealing_builder, + }; + use rstest::*; + use std::thread; + use std::time::Duration; + + #[rstest] + fn test_swarm_weighted( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + const NUM_THREADS: usize = 4; + const BATCH_LIMIT: usize = 24; + let hive = builder_factory(false) + .with_worker_default::>() + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT) + .build(); + let inputs = (0..10u8) + .map(|i| { + Thunk::from(move || { + thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + i + }) + }) + .map(|thunk| (thunk, 0)) + .into_weighted(); + let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); + assert_eq!(outputs, (0..10).collect::>()) + } #[rstest] fn test_swarm_default_weighted( @@ -2200,7 +2293,7 @@ mod local_batch_tests { .build(); let inputs = (0..10u8) .map(|i| { - Thunk::of(move || { + Thunk::from(move || { thread::sleep(Duration::from_millis((10 - i as u64) * 100)); i }) @@ -2209,6 +2302,51 @@ mod local_batch_tests { let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); assert_eq!(outputs, (0..10).collect::>()) } + + #[rstest] + fn test_swarm_const_weighted( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + const NUM_THREADS: usize = 4; + const BATCH_LIMIT: usize = 24; + let hive = builder_factory(false) + .with_worker_default::>() + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT) + .build(); + let inputs = (0..10u8) + .map(|i| { + Thunk::from(move || { + thread::sleep(Duration::from_millis((10 - i as u64) * 100)); + i + }) + }) + .into_const_weighted(0); + let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); + assert_eq!(outputs, (0..10).collect::>()) + } + + #[rstest] + fn test_swarm_identity_weighted( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + const NUM_THREADS: usize = 4; + const BATCH_LIMIT: usize = 24; + let hive = builder_factory(false) + .with_worker_default::>() + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT) + .build(); + let inputs = (0..10u8).into_identity_weighted(); + let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); + assert_eq!(outputs, (0..10).collect::>()) + } } #[cfg(all(test, feature = "retry"))] @@ -2243,7 +2381,7 @@ mod retry_tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(RetryCaller::of(echo_time)) + .with_worker(RetryCaller::from(echo_time)) .with_thread_per_core() .max_retries(3) .retry_factor(Duration::from_secs(1)) @@ -2278,7 +2416,7 @@ mod retry_tests { } let hive = builder_factory(false) - .with_worker(RetryCaller::of(sometimes_fail)) + .with_worker(RetryCaller::from(sometimes_fail)) .with_thread_per_core() .max_retries(3) .build(); @@ -2306,7 +2444,7 @@ mod retry_tests { F: Fn(bool) -> B, { let hive = builder_factory(false) - .with_worker(RetryCaller::of(echo_time)) + .with_worker(RetryCaller::from(echo_time)) .with_thread_per_core() .with_no_retries() .build(); diff --git 
a/src/hive/outcome/batch.rs b/src/hive/outcome/batch.rs
index 1bdbcdf..b42efff 100644
--- a/src/hive/outcome/batch.rs
+++ b/src/hive/outcome/batch.rs
@@ -1,9 +1,13 @@
 use super::{DerefOutcomes, Outcome, OwnedOutcomes};
 use crate::bee::{TaskId, Worker};
+use derive_more::Debug;
+use std::any;
 use std::collections::HashMap;
 use std::ops::{Deref, DerefMut};
 
 /// A batch of `Outcome`s.
+#[derive(Debug)]
+#[debug("OutcomeBatch<{}>", any::type_name::())]
 pub struct OutcomeBatch(HashMap>);
 
 impl OutcomeBatch {
@@ -49,6 +53,7 @@ impl DerefOutcomes for OutcomeBatch {
 
 /// Functions only used in testing.
 #[cfg(test)]
+#[cfg_attr(coverage_nightly, coverage(off))]
 impl OutcomeBatch {
     pub(crate) fn empty() -> Self {
         OutcomeBatch::new(HashMap::new())
diff --git a/src/hive/outcome/impl.rs b/src/hive/outcome/impl.rs
index 6f97bd0..5d66ea1 100644
--- a/src/hive/outcome/impl.rs
+++ b/src/hive/outcome/impl.rs
@@ -1,7 +1,7 @@
 use super::Outcome;
 use crate::bee::{ApplyError, TaskId, TaskMeta, Worker, WorkerResult};
 use std::cmp::Ordering;
-use std::fmt::Debug;
+use std::mem;
 
 impl Outcome {
     /// Converts a worker `result` into an `Outcome` with the given task_id and optional subtask ids.
@@ -180,6 +180,16 @@ impl Outcome {
         }
     }
 
+    /// Returns a reference to the wrapped error, if any.
+    pub fn error(&self) -> Option<&W::Error> {
+        match self {
+            Self::Failure { error, .. } | Self::FailureWithSubtasks { error, .. } => Some(error),
+            #[cfg(feature = "retry")]
+            Self::MaxRetriesAttempted { error, .. } => Some(error),
+            _ => None,
+        }
+    }
+
     /// Consumes this `Outcome` and depending on the variant:
     /// * Returns the wrapped error if this is a `Failure` or `MaxRetriesAttempted`,
     /// * Resumes unwinding if this is a `Panic` outcome,
@@ -203,111 +213,9 @@ impl Outcome {
     }
 }
 
-impl Debug for Outcome {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Self::Success { task_id, .. } => {
-                f.debug_struct("Success").field("task_id", task_id).finish()
-            }
-            Self::SuccessWithSubtasks {
-                task_id,
-                subtask_ids,
-                ..
-            } => f
-                .debug_struct("SuccessWithSubtasks")
-                .field("task_id", task_id)
-                .field("subtask_ids", subtask_ids)
-                .finish(),
-            Self::Failure { error, task_id, .. } => f
-                .debug_struct("Failure")
-                .field("error", error)
-                .field("task_id", task_id)
-                .finish(),
-            Self::FailureWithSubtasks {
-                error,
-                task_id,
-                subtask_ids,
-                ..
-            } => f
-                .debug_struct("FailureWithSubtasks")
-                .field("error", error)
-                .field("task_id", task_id)
-                .field("subtask_ids", subtask_ids)
-                .finish(),
-            Self::Unprocessed { task_id, .. } => f
-                .debug_struct("Unprocessed")
-                .field("task_id", task_id)
-                .finish(),
-            Self::UnprocessedWithSubtasks {
-                task_id,
-                subtask_ids,
-                ..
-            } => f
-                .debug_struct("UnprocessedWithSubtasks")
-                .field("task_id", task_id)
-                .field("subtask_ids", subtask_ids)
-                .finish(),
-            Self::Missing { task_id } => {
-                f.debug_struct("Missing").field("task_id", task_id).finish()
-            }
-            Self::Panic { task_id, .. } => {
-                f.debug_struct("Panic").field("task_id", task_id).finish()
-            }
-            Self::PanicWithSubtasks {
-                task_id,
-                subtask_ids,
-                ..
-            } => f
-                .debug_struct("PanicWithSubtasks")
-                .field("task_id", task_id)
-                .field("subtask_ids", subtask_ids)
-                .finish(),
-            #[cfg(feature = "local-batch")]
-            Self::WeightLimitExceeded { task_id, .. } => f
-                .debug_struct("WeightLimitExceeded")
-                .field("task_id", task_id)
-                .finish(),
-            #[cfg(feature = "retry")]
-            Self::MaxRetriesAttempted { error, task_id, ..
} => f - .debug_struct("MaxRetriesAttempted") - .field("error", error) - .field("task_id", task_id) - .finish(), - } - } -} - impl PartialEq for Outcome { fn eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Success { task_id: a, .. }, Self::Success { task_id: b, .. }) - | ( - Self::SuccessWithSubtasks { task_id: a, .. }, - Self::SuccessWithSubtasks { task_id: b, .. }, - ) - | (Self::Failure { task_id: a, .. }, Self::Failure { task_id: b, .. }) - | ( - Self::FailureWithSubtasks { task_id: a, .. }, - Self::FailureWithSubtasks { task_id: b, .. }, - ) - | (Self::Unprocessed { task_id: a, .. }, Self::Unprocessed { task_id: b, .. }) - | ( - Self::UnprocessedWithSubtasks { task_id: a, .. }, - Self::UnprocessedWithSubtasks { task_id: b, .. }, - ) - | (Self::Missing { task_id: a }, Self::Missing { task_id: b }) - | (Self::Panic { task_id: a, .. }, Self::Panic { task_id: b, .. }) - | ( - Self::PanicWithSubtasks { task_id: a, .. }, - Self::PanicWithSubtasks { task_id: b, .. }, - ) => a == b, - #[cfg(feature = "retry")] - ( - Self::MaxRetriesAttempted { task_id: a, .. }, - Self::MaxRetriesAttempted { task_id: b, .. }, - ) => a == b, - _ => false, - } + mem::discriminant(self) == mem::discriminant(other) && self.task_id() == other.task_id() } } @@ -326,14 +234,121 @@ impl Ord for Outcome { } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::Outcome; use crate::bee::stock::EchoWorker; + use crate::bee::{ApplyError, TaskMeta, WorkerResult}; use crate::panic::Panic; type Worker = EchoWorker; type WorkerOutcome = Outcome; + #[test] + fn test_success() { + let outcome = WorkerOutcome::Success { + value: 42, + task_id: 1, + }; + assert_eq!(outcome.success(), Some(42)); + } + + #[test] + fn test_unwrap() { + let outcome = WorkerOutcome::Success { + value: 42, + task_id: 1, + }; + assert_eq!(outcome.unwrap(), 42); + } + + #[test] + fn test_success_on_error() { + let outcome = WorkerOutcome::Failure { + input: Some(42), + error: (), + task_id: 1, + }; + assert!(matches!(outcome.success(), None)) + } + + #[test] + #[should_panic] + fn test_unwrap_panics_on_error() { + let outcome = WorkerOutcome::Failure { + input: Some(42), + error: (), + task_id: 1, + }; + let _ = outcome.unwrap(); + } + + #[test] + fn test_retry_with_subtasks_into_failure() { + let input = 1; + let task_id = 1; + let error = (); + let result = WorkerResult::::Err(ApplyError::Retryable { input, error }); + let task_meta = TaskMeta::new(task_id); + let subtask_ids = vec![2, 3, 4]; + let outcome = + WorkerOutcome::from_worker_result(result, task_meta, Some(subtask_ids.clone())); + let expected_outcome = WorkerOutcome::FailureWithSubtasks { + input: Some(input), + error, + task_id, + subtask_ids, + }; + assert_eq!(outcome, expected_outcome); + } + + #[test] + fn test_subtasks() { + let input = 1; + let task_id = 1; + let error = (); + let task_meta = TaskMeta::new(task_id); + let subtask_ids = vec![2, 3, 4]; + + let result = WorkerResult::::Err(ApplyError::Fatal { + input: Some(input), + error, + }); + let outcome = + WorkerOutcome::from_worker_result(result, task_meta.clone(), Some(subtask_ids.clone())); + let expected_outcome = WorkerOutcome::FailureWithSubtasks { + input: Some(1), + task_id: 1, + error: (), + subtask_ids: vec![2, 3, 4], + }; + assert_eq!(outcome, expected_outcome); + + let result = WorkerResult::::Err(ApplyError::Cancelled { input }); + let outcome = + WorkerOutcome::from_worker_result(result, task_meta.clone(), Some(subtask_ids.clone())); + let expected_outcome = 
WorkerOutcome::UnprocessedWithSubtasks { + input: 1, + task_id: 1, + subtask_ids: vec![2, 3, 4], + }; + assert_eq!(outcome, expected_outcome); + + let result = WorkerResult::::Err(ApplyError::Panic { + input: Some(input), + payload: Panic::new("panicked", None), + }); + let outcome = + WorkerOutcome::from_worker_result(result, task_meta.clone(), Some(subtask_ids.clone())); + let expected_outcome = WorkerOutcome::PanicWithSubtasks { + input: Some(1), + task_id: 1, + subtask_ids: vec![2, 3, 4], + payload: Panic::new("panicked", None), + }; + assert_eq!(outcome, expected_outcome); + } + #[test] fn test_try_into_input() { let outcome = WorkerOutcome::Success { @@ -453,6 +468,7 @@ mod tests { #[cfg(all(test, feature = "retry"))] mod retry_tests { use super::Outcome; + use crate::bee::TaskMeta; use crate::bee::stock::EchoWorker; type Worker = EchoWorker; @@ -477,4 +493,18 @@ mod retry_tests { }; assert_eq!(outcome.try_into_error(), Some(())); } + + #[test] + fn test_from_fatal() { + let input = 1; + let task_id = 1; + let error = (); + let outcome = WorkerOutcome::from_fatal(input, TaskMeta::new(task_id), error); + let expected_outcome = WorkerOutcome::Failure { + input: Some(input), + task_id, + error, + }; + assert_eq!(outcome, expected_outcome); + } } diff --git a/src/hive/outcome/iter.rs b/src/hive/outcome/iter.rs index 47b8c3f..abad145 100644 --- a/src/hive/outcome/iter.rs +++ b/src/hive/outcome/iter.rs @@ -250,6 +250,7 @@ pub trait OutcomeIteratorExt: IntoIterator> + Sized impl>> OutcomeIteratorExt for T {} #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::{OrderedOutcomeIterator, UnorderedOutcomeIterator}; use crate::bee::stock::EchoWorker; diff --git a/src/hive/outcome/mod.rs b/src/hive/outcome/mod.rs index 10f577d..b9f9f9f 100644 --- a/src/hive/outcome/mod.rs +++ b/src/hive/outcome/mod.rs @@ -1,6 +1,6 @@ mod batch; -mod iter; mod r#impl; +mod iter; mod queue; mod store; @@ -13,6 +13,7 @@ pub(super) use self::store::{DerefOutcomes, OwnedOutcomes}; use crate::bee::{TaskId, Worker}; use crate::panic::Panic; +use derive_more::Debug; /// The possible outcomes of a task execution. /// @@ -22,12 +23,18 @@ use crate::panic::Panic; /// /// Note that `Outcome`s can only be compared or ordered with other `Outcome`s produced by the same /// `Hive`, because comparison/ordering is completely based on the task ID. +#[derive(Debug)] pub enum Outcome { /// The task was executed successfully. - Success { value: W::Output, task_id: TaskId }, + Success { + #[debug(skip)] + value: W::Output, + task_id: TaskId, + }, /// The task was executed successfully, and it also submitted one or more subtask_ids to the /// `Hive`. SuccessWithSubtasks { + #[debug(skip)] value: W::Output, task_id: TaskId, subtask_ids: Vec, @@ -35,6 +42,7 @@ pub enum Outcome { /// The task failed with an error that was not retryable. The input value that caused the /// failure is provided if possible. Failure { + #[debug(skip)] input: Option, error: W::Error, task_id: TaskId, @@ -42,6 +50,7 @@ pub enum Outcome { /// The task failed with an error that was not retryable, but it submitted one or more subtask_ids /// before failing. The input value that caused the failure is provided if possible. FailureWithSubtasks { + #[debug(skip)] input: Option, error: W::Error, task_id: TaskId, @@ -49,10 +58,15 @@ pub enum Outcome { }, /// The task was not executed before the Hive was dropped, or processing of the task was /// interrupted (e.g., by `suspend`ing the `Hive`). 
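// A minimal, self-contained sketch of the `mem::discriminant` technique that the
// rewritten `PartialEq for Outcome` above relies on: two values compare equal when
// they are the same enum variant and carry the same ID, without enumerating every
// variant pair. The `Status` enum here is a hypothetical stand-in, not part of this
// patch.
use std::mem;

#[derive(Debug)]
enum Status {
    Ready { task_id: usize },
    Failed { task_id: usize, error: String },
}

impl Status {
    fn task_id(&self) -> usize {
        match self {
            Status::Ready { task_id } | Status::Failed { task_id, .. } => *task_id,
        }
    }
}

impl PartialEq for Status {
    fn eq(&self, other: &Self) -> bool {
        // same variant (payloads ignored) and same task ID
        mem::discriminant(self) == mem::discriminant(other) && self.task_id() == other.task_id()
    }
}

fn main() {
    let a = Status::Failed { task_id: 1, error: "x".into() };
    let b = Status::Failed { task_id: 1, error: "y".into() };
    assert_eq!(a, b); // differing payloads do not affect equality
    assert_ne!(a, Status::Ready { task_id: 1 });
}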
- Unprocessed { input: W::Input, task_id: TaskId }, + Unprocessed { + #[debug(skip)] + input: W::Input, + task_id: TaskId, + }, /// The task was not executed before the Hive was dropped, or processing of the task was /// interrupted (e.g., by `suspend`ing the `Hive`), but it first submitted one or more subtask_ids. UnprocessedWithSubtasks { + #[debug(skip)] input: W::Input, task_id: TaskId, subtask_ids: Vec, @@ -62,6 +76,7 @@ pub enum Outcome { Missing { task_id: TaskId }, /// The task panicked. The input value that caused the panic is provided if possible. Panic { + #[debug(skip)] input: Option, payload: Panic, task_id: TaskId, @@ -69,6 +84,7 @@ pub enum Outcome { /// The task panicked, but it submitted one or more subtask_ids before panicking. The input value /// that caused the panic is provided if possible. PanicWithSubtasks { + #[debug(skip)] input: Option, payload: Panic, task_id: TaskId, @@ -77,6 +93,7 @@ pub enum Outcome { /// The task's weight was larger than the configured limit for the `Hive`. #[cfg(feature = "local-batch")] WeightLimitExceeded { + #[debug(skip)] input: W::Input, weight: u32, task_id: TaskId, @@ -84,6 +101,7 @@ pub enum Outcome { /// The task failed after retrying the maximum number of times. #[cfg(feature = "retry")] MaxRetriesAttempted { + #[debug(skip)] input: W::Input, error: W::Error, task_id: TaskId, diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs index cb6c0a2..da66902 100644 --- a/src/hive/outcome/queue.rs +++ b/src/hive/outcome/queue.rs @@ -22,10 +22,7 @@ impl OutcomeQueue { /// Flushes the queue into the map of outcomes and returns a mutable reference to the map. pub fn get_mut(&self) -> impl DerefMut>> { let mut outcomes = self.outcomes.lock(); - // add any queued outcomes to the map - while let Some(outcome) = self.queue.pop() { - outcomes.insert(*outcome.task_id(), outcome); - } + drain_into(&self.queue, &mut outcomes); outcomes } @@ -33,24 +30,25 @@ impl OutcomeQueue { /// returns them. pub fn drain(&self) -> HashMap> { let mut outcomes: HashMap> = self.outcomes.lock().drain().collect(); - // add any queued outcomes to the map - while let Some(outcome) = self.queue.pop() { - outcomes.insert(*outcome.task_id(), outcome); - } + drain_into(&self.queue, &mut outcomes); outcomes } /// Consumes this `OutcomeQueue`, drains the queue, and returns the outcomes as a map. 
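// A standalone sketch of the `drain_into` pattern factored out above: pop a lock-free
// `SegQueue` until it is empty and fold the items into a `HashMap`, so `get_mut`,
// `drain`, and `into_inner` can share one loop. The `(usize, String)` item type is a
// simplified stand-in for an `Outcome` keyed by its task ID.
use crossbeam_queue::SegQueue;
use std::collections::HashMap;

fn drain_into(queue: &SegQueue<(usize, String)>, map: &mut HashMap<usize, String>) {
    // `pop` returns `None` once the queue is empty; a later item with the same key
    // overwrites an earlier one, just as repeated `insert` calls would
    while let Some((id, value)) = queue.pop() {
        map.insert(id, value);
    }
}

fn main() {
    let queue = SegQueue::new();
    queue.push((1, "a".to_string()));
    queue.push((2, "b".to_string()));
    let mut map = HashMap::new();
    drain_into(&queue, &mut map);
    assert_eq!(map.len(), 2);
}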
pub fn into_inner(self) -> HashMap> { let mut outcomes = self.outcomes.into_inner(); - // add any queued outcomes to the map - while let Some(outcome) = self.queue.pop() { - outcomes.insert(*outcome.task_id(), outcome); - } + drain_into(&self.queue, &mut outcomes); outcomes } } +#[inline] +fn drain_into(queue: &SegQueue>, outcomes: &mut HashMap>) { + while let Some(outcome) = queue.pop() { + outcomes.insert(*outcome.task_id(), outcome); + } +} + impl Default for OutcomeQueue { fn default() -> Self { Self { @@ -62,10 +60,63 @@ impl Default for OutcomeQueue { impl DerefOutcomes for OutcomeQueue { fn outcomes_deref(&self) -> impl Deref>> { - self.outcomes.lock() + self.get_mut() } fn outcomes_deref_mut(&mut self) -> impl DerefMut>> { - self.outcomes.lock() + self.get_mut() + } +} + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use super::*; + use crate::bee::stock::EchoWorker; + use crate::hive::OutcomeStore; + + #[test] + fn test_works() { + let queue = OutcomeQueue::>::default(); + queue.push(Outcome::Success { + value: 42, + task_id: 1, + }); + queue.push(Outcome::Unprocessed { + input: 43, + task_id: 2, + }); + queue.push(Outcome::Failure { + input: Some(44), + error: (), + task_id: 3, + }); + assert_eq!(queue.count(), (1, 1, 1)); + queue.push(Outcome::Missing { task_id: 4 }); + let outcomes = queue.drain(); + assert_eq!(outcomes.len(), 4); + assert_eq!( + outcomes[&1], + Outcome::Success { + value: 42, + task_id: 1 + } + ); + assert_eq!( + outcomes[&2], + Outcome::Unprocessed { + input: 43, + task_id: 2 + } + ); + assert_eq!( + outcomes[&3], + Outcome::Failure { + input: Some(44), + error: (), + task_id: 3 + } + ); + assert_eq!(outcomes[&4], Outcome::Missing { task_id: 4 }) } } diff --git a/src/hive/outcome/store.rs b/src/hive/outcome/store.rs index 8a93578..45b5733 100644 --- a/src/hive/outcome/store.rs +++ b/src/hive/outcome/store.rs @@ -339,6 +339,7 @@ pub trait OutcomeStore: DerefOutcomes { impl> OutcomeStore for D {} #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::OutcomeStore; use crate::bee::{Context, Worker, WorkerResult}; diff --git a/src/hive/weighted.rs b/src/hive/weighted.rs index b6defed..317be43 100644 --- a/src/hive/weighted.rs +++ b/src/hive/weighted.rs @@ -1,7 +1,9 @@ //! Weighted value used for task submission with the `local-batch` feature. +use num::ToPrimitive; use std::ops::Deref; /// Wraps a value of type `T` and an associated weight. +#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Weighted { value: T, weight: u32, @@ -9,8 +11,11 @@ pub struct Weighted { impl Weighted { /// Creates a new `Weighted` instance with the given value and weight. - pub fn new(value: T, weight: u32) -> Self { - Self { value, weight } + pub fn new(value: T, weight: P) -> Self { + Self { + value, + weight: weight.to_u32().unwrap(), + } } /// Creates a new `Weighted` instance with the given value and weight obtained from calling the @@ -27,9 +32,9 @@ impl Weighted { /// the value into a `u32`. 
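// A quick sketch of the `num::ToPrimitive` conversion that `Weighted::new` above now
// accepts for its weight parameter. Note that `to_u32` returns an `Option`, so the
// `unwrap()` in the patch panics for negative or oversized weights; the values below
// are arbitrary illustrations.
use num::ToPrimitive;

fn to_weight<P: ToPrimitive>(weight: P) -> u32 {
    weight.to_u32().unwrap()
}

fn main() {
    assert_eq!(to_weight(7u64), 7);
    assert_eq!(to_weight(7i8), 7);
    // out-of-range values convert to `None` and would panic inside `to_weight`
    assert_eq!((-1i32).to_u32(), None);
    assert_eq!(u64::MAX.to_u32(), None);
}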
pub fn from_identity(value: T) -> Self where - T: Into + Clone, + T: ToPrimitive + Clone, { - let weight = value.clone().into(); + let weight = value.clone().to_u32().unwrap(); Self::new(value, weight) } @@ -58,8 +63,8 @@ impl From for Weighted { } } -impl From<(T, u32)> for Weighted { - fn from((value, weight): (T, u32)) -> Self { +impl From<(T, P)> for Weighted { + fn from((value, weight): (T, P)) -> Self { Self::new(value, weight) } } @@ -86,10 +91,9 @@ pub trait WeightedIteratorExt: IntoIterator + Sized { fn into_identity_weighted(self) -> impl Iterator> where - Self::Item: Into + Clone, + Self::Item: ToPrimitive + Clone, { - self.into_iter() - .map(move |item| Weighted::from_identity(item)) + self.into_iter().map(Weighted::from_identity) } fn into_weighted_zip(self, weights: W) -> impl Iterator> @@ -101,6 +105,16 @@ pub trait WeightedIteratorExt: IntoIterator + Sized { .zip(weights.into_iter().chain(std::iter::repeat(0))) .map(Into::into) } + + fn into_weighted_with(self, f: F) -> impl Iterator> + where + F: Fn(&Self::Item) -> u32, + { + self.into_iter().map(move |item| { + let weight = f(&item); + Weighted::new(item, weight) + }) + } } impl WeightedIteratorExt for T {} @@ -134,11 +148,21 @@ pub trait WeightedExactSizeIteratorExt: IntoIterator + Sized { fn into_identity_weighted(self) -> impl ExactSizeIterator> where - Self::Item: Into + Clone, + Self::Item: ToPrimitive + Clone, Self::IntoIter: ExactSizeIterator + 'static, { - self.into_iter() - .map(move |item| Weighted::from_identity(item)) + self.into_iter().map(Weighted::from_identity) + } + + fn into_weighted_with(self, f: F) -> impl ExactSizeIterator> + where + Self::IntoIter: ExactSizeIterator + 'static, + F: Fn(&Self::Item) -> u32, + { + self.into_iter().map(move |item| { + let weight = f(&item); + Weighted::new(item, weight) + }) } } @@ -148,3 +172,137 @@ where T::IntoIter: ExactSizeIterator, { } + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use super::Weighted; + + #[test] + fn test_new() { + let weighted = Weighted::new(42, 10); + assert_eq!(*weighted, 42); + assert_eq!(weighted.weight(), 10); + assert_eq!(weighted.into_parts(), (42, 10)); + } + + #[test] + fn test_from_fn() { + let weighted = Weighted::from_fn(42, |x| x * 2); + assert_eq!(*weighted, 42); + assert_eq!(weighted.weight(), 84); + } + + #[test] + fn test_from_identity() { + let weighted = Weighted::from_identity(42); + assert_eq!(*weighted, 42); + assert_eq!(weighted.weight(), 42); + } + + #[test] + fn test_from_unweighted() { + let weighted = Weighted::from(42); + assert_eq!(*weighted, 42); + assert_eq!(weighted.weight(), 0); + } + + #[test] + fn test_from_tuple() { + let weighted: Weighted = Weighted::from((42, 10)); + assert_eq!(*weighted, 42); + assert_eq!(weighted.weight(), 10); + assert_eq!(weighted.into_parts(), (42, 10)); + } +} + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod iter_tests { + use super::WeightedIteratorExt; + + #[test] + fn test_into_weighted() { + (0..10) + .map(|i| (i, i)) + .into_weighted() + .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value)); + } + + #[test] + fn test_into_default_weighted() { + (0..10) + .into_default_weighted() + .for_each(|weighted| assert_eq!(weighted.weight(), 0)); + } + + #[test] + fn test_into_identity_weighted() { + (0..10) + .into_identity_weighted() + .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value)); + } + + #[test] + fn test_into_const_weighted() { + (0..10) + .into_const_weighted(5) + .for_each(|weighted| 
assert_eq!(weighted.weight(), 5)); + } + + #[test] + fn test_into_weighted_zip() { + (0..10) + .into_weighted_zip(10..20) + .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value + 10)); + } + + #[test] + fn test_into_weighted_with() { + (0..10) + .into_weighted_with(|i| i * 2) + .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value * 2)); + } +} + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod exact_iter_test { + use super::WeightedExactSizeIteratorExt; + + #[test] + fn test_into_weighted() { + (0..10) + .map(|i| (i, i)) + .into_weighted() + .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value)); + } + + #[test] + fn test_into_default_weighted() { + (0..10) + .into_default_weighted() + .for_each(|weighted| assert_eq!(weighted.weight(), 0)); + } + + #[test] + fn test_into_identity_weighted() { + (0..10) + .into_identity_weighted() + .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value)); + } + + #[test] + fn test_into_const_weighted() { + (0..10) + .into_const_weighted(5) + .for_each(|weighted| assert_eq!(weighted.weight(), 5)); + } + + #[test] + fn test_into_weighted_with() { + (0..10) + .into_weighted_with(|i| i * 2) + .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value * 2)); + } +} diff --git a/src/lib.rs b/src/lib.rs index 351bfd5..2512472 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] //! A Rust library that provides a [thread pool](https://en.wikipedia.org/wiki/Thread_pool) //! implementation designed to execute the same operation in parallel on any number of inputs (this //! is sometimes called a "worker pool"). @@ -327,6 +328,7 @@ //! ``` mod atomic; #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod barrier; pub mod bee; mod boxed; diff --git a/src/panic.rs b/src/panic.rs index d6adef6..06f76ad 100644 --- a/src/panic.rs +++ b/src/panic.rs @@ -1,19 +1,21 @@ //! Data type that wraps a `panic` payload. use super::boxed::BoxedFnOnce; +use derive_more::Debug; use std::any::Any; -use std::fmt::Debug; +use std::fmt; use std::panic::AssertUnwindSafe; pub type PanicPayload = Box; /// Wraps a payload from a caught `panic` with an optional `detail`. #[derive(Debug)] -pub struct Panic { +pub struct Panic { + #[debug("")] payload: PanicPayload, detail: Option, } -impl Panic { +impl Panic { /// Attempts to call the provided function `f` and catches any panic. Returns either the return /// value of the function or a `Panic` created from the panic payload and the provided `detail`. 
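// A minimal sketch of the machinery `Panic::try_call` above is built on, per its
// imports: `std::panic::catch_unwind` wrapped in `AssertUnwindSafe`, with the payload
// kept as `Box<dyn Any + Send>`. The `detail` pairing mirrors the struct above but is
// simplified; this is not the crate's actual implementation.
use std::any::Any;
use std::panic::{catch_unwind, AssertUnwindSafe};

type Payload = Box<dyn Any + Send + 'static>;

fn try_call<T>(
    detail: Option<String>,
    f: impl FnOnce() -> T,
) -> Result<T, (Payload, Option<String>)> {
    // AssertUnwindSafe is needed because the closure may capture non-unwind-safe state
    catch_unwind(AssertUnwindSafe(f)).map_err(|payload| (payload, detail))
}

fn main() {
    assert!(try_call(None, || 42).is_ok());
    // the default panic hook still prints to stderr when the closure panics
    let err = try_call(Some("boom".into()), || -> i32 { panic!("oops") }).unwrap_err();
    // a string-literal panic payload downcasts to `&str`
    assert_eq!(err.0.downcast_ref::<&str>(), Some(&"oops"));
    assert_eq!(err.1.as_deref(), Some("boom"));
}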
pub fn try_call O>(detail: Option, f: F) -> Result { @@ -44,15 +46,16 @@ impl Panic { } } -impl PartialEq for Panic { +impl PartialEq for Panic { fn eq(&self, other: &Self) -> bool { (*self.payload).type_id() == (*other.payload).type_id() && self.detail == other.detail } } -impl Eq for Panic {} +impl Eq for Panic {} #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::Panic; use std::fmt::Debug; diff --git a/src/util.rs b/src/util.rs index 0ae0be8..1bbb385 100644 --- a/src/util.rs +++ b/src/util.rs @@ -30,7 +30,7 @@ where { ChannelBuilder::default() .num_threads(num_threads) - .with_worker(Caller::of(f)) + .with_worker(Caller::from(f)) .build() .map(inputs) .map(Outcome::unwrap) @@ -70,13 +70,14 @@ where { ChannelBuilder::default() .num_threads(num_threads) - .with_worker(OnceCaller::of(f)) + .with_worker(OnceCaller::from(f)) .build() .map(inputs) .into() } #[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] mod tests { use crate::hive::{Outcome, OutcomeStore}; @@ -157,13 +158,14 @@ mod retry { ChannelBuilder::default() .num_threads(num_threads) .max_retries(max_retries) - .with_worker(RetryCaller::of(f)) + .with_worker(RetryCaller::from(f)) .build() .map(inputs) .into() } #[cfg(test)] + #[cfg_attr(coverage_nightly, coverage(off))] mod tests { use crate::bee::ApplyError; use crate::hive::{Outcome, OutcomeStore}; From 1b128f407a0c24a5d9ad7ff54db0a9e962d78feb Mon Sep 17 00:00:00 2001 From: jdidion Date: Wed, 12 Mar 2025 13:25:07 -0700 Subject: [PATCH 46/67] fix lints --- src/hive/outcome/impl.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hive/outcome/impl.rs b/src/hive/outcome/impl.rs index 5d66ea1..79cf339 100644 --- a/src/hive/outcome/impl.rs +++ b/src/hive/outcome/impl.rs @@ -269,7 +269,7 @@ mod tests { error: (), task_id: 1, }; - assert!(matches!(outcome.success(), None)) + assert!(outcome.success().is_none()); } #[test] From a9d74160bf4627c4182eaceda3c5fdb4281476fe Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 11:00:11 -0700 Subject: [PATCH 47/67] add tests --- src/bee/context.rs | 9 + src/hive/builder/full.rs | 49 ++++ src/hive/builder/open.rs | 127 +++++++++++ src/hive/context.rs | 5 + src/hive/cores.rs | 205 ++++++++++++----- src/hive/hive.rs | 110 +++++---- src/hive/inner/queue/channel.rs | 5 + src/hive/inner/queue/mod.rs | 4 + src/hive/inner/queue/workstealing.rs | 9 + src/hive/inner/shared.rs | 9 +- src/hive/mock.rs | 5 + src/hive/mod.rs | 324 ++++++++++++++++++++++++++- src/hive/outcome/queue.rs | 10 +- 13 files changed, 730 insertions(+), 141 deletions(-) diff --git a/src/bee/context.rs b/src/bee/context.rs index 36eb2e8..ea7c8ba 100644 --- a/src/bee/context.rs +++ b/src/bee/context.rs @@ -13,6 +13,9 @@ pub trait LocalContext: Debug { /// Submits a new task to the `Hive` that is executing the current task. fn submit_task(&self, input: I) -> TaskId; + + #[cfg(test)] + fn thread_index(&self) -> usize; } /// The context visible to a task when processing an input. @@ -87,6 +90,12 @@ impl<'a, I> Context<'a, I> { } } + /// Returns the unique index of the worker thread executing this task. + #[cfg(test)] + pub fn thread_index(&self) -> Option { + self.local.map(|local| local.thread_index()) + } + /// Consumes this `Context` and returns the IDs of the subtasks spawned during the execution /// of the task, if any. 
pub(crate) fn into_parts(self) -> (TaskMeta, Option>) { diff --git a/src/hive/builder/full.rs b/src/hive/builder/full.rs index 52de007..2662bcd 100644 --- a/src/hive/builder/full.rs +++ b/src/hive/builder/full.rs @@ -67,3 +67,52 @@ impl> BuilderConfig for FullBuilder { &mut self.config } } + +#[cfg(test)] +#[cfg_attr(coverage_nightly, coverage(off))] +mod tests { + use super::*; + use crate::bee::Queen; + use crate::bee::stock::EchoWorker; + use crate::hive::{ChannelTaskQueues, WorkstealingTaskQueues}; + use rstest::rstest; + + #[derive(Clone, Default)] + struct TestQueen; + + impl Queen for TestQueen { + type Kind = EchoWorker; + + fn create(&self) -> Self::Kind { + EchoWorker::default() + } + } + + #[rstest] + fn test_channel( + #[values( + FullBuilder::>>::empty::, + FullBuilder::>>::preset:: + )] + factory: F, + ) where + F: Fn(TestQueen) -> FullBuilder>>, + { + let builder = factory(TestQueen); + let _hive = builder.build(); + } + + #[rstest] + fn test_workstealing( + #[values( + FullBuilder::>>::empty::, + FullBuilder::>>::preset:: + )] + factory: F, + ) where + F: Fn(TestQueen) -> FullBuilder>>, + { + let builder = factory(TestQueen); + let _hive = builder.build(); + } +} diff --git a/src/hive/builder/open.rs b/src/hive/builder/open.rs index 14c9196..ea81a06 100644 --- a/src/hive/builder/open.rs +++ b/src/hive/builder/open.rs @@ -266,3 +266,130 @@ impl From for OpenBuilder { Self(value) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::bee::stock::EchoWorker; + use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut}; + use crate::hive::{FullBuilder, TaskQueues, TaskQueuesBuilder}; + use rstest::rstest; + + #[derive(Clone, Default)] + struct TestQueen; + + impl Queen for TestQueen { + type Kind = EchoWorker; + + fn create(&self) -> Self::Kind { + EchoWorker::default() + } + } + + impl QueenMut for TestQueen { + type Kind = EchoWorker; + + fn create(&mut self) -> Self::Kind { + EchoWorker::default() + } + } + + #[rstest] + fn test_create( + #[values(OpenBuilder::empty, OpenBuilder::default)] factory: F, + #[values( + OpenBuilder::with_channel_queues, + OpenBuilder::with_workstealing_queues + )] + with_fn: W, + ) where + F: Fn() -> OpenBuilder, + B: TaskQueuesBuilder, + W: Fn(OpenBuilder) -> B, + { + let open_builder = factory(); + let queue_builder = with_fn(open_builder); + let _hive = queue_builder + .with_worker(EchoWorker::::default()) + .build(); + } + + #[rstest] + fn test_queen( + #[values(OpenBuilder::empty, OpenBuilder::default)] factory: F, + #[values(BeeBuilder::with_channel_queues, BeeBuilder::with_workstealing_queues)] with_fn: W, + ) where + F: Fn() -> OpenBuilder, + T: TaskQueues>, + W: Fn(BeeBuilder) -> FullBuilder, + { + let open_builder = factory(); + let bee_builder = open_builder.with_queen(TestQueen); + let queue_builder = with_fn(bee_builder); + let _hive = queue_builder.build(); + } + + #[rstest] + fn test_queen_default( + #[values(OpenBuilder::empty, OpenBuilder::default)] factory: F, + #[values(BeeBuilder::with_channel_queues, BeeBuilder::with_workstealing_queues)] with_fn: W, + ) where + F: Fn() -> OpenBuilder, + T: TaskQueues>, + W: Fn(BeeBuilder) -> FullBuilder, + { + let open_builder = factory(); + let bee_builder = open_builder.with_queen_default::(); + let queue_builder = with_fn(bee_builder); + let _hive = queue_builder.build(); + } + + #[rstest] + fn test_queen_mut_default( + #[values(OpenBuilder::empty, OpenBuilder::default)] factory: F, + #[values(BeeBuilder::with_channel_queues, BeeBuilder::with_workstealing_queues)] 
with_fn: W, + ) where + F: Fn() -> OpenBuilder, + T: TaskQueues>, + W: Fn(BeeBuilder>) -> FullBuilder, T>, + { + let open_builder = factory(); + let bee_builder = open_builder.with_queen_mut_default::(); + let queue_builder = with_fn(bee_builder); + let _hive = queue_builder.build(); + } + + #[rstest] + fn test_worker( + #[values(OpenBuilder::empty, OpenBuilder::default)] factory: F, + #[values(BeeBuilder::with_channel_queues, BeeBuilder::with_workstealing_queues)] with_fn: W, + ) where + F: Fn() -> OpenBuilder, + T: TaskQueues>, + W: Fn( + BeeBuilder>>, + ) -> FullBuilder>, T>, + { + let open_builder = factory(); + let bee_builder = open_builder.with_worker(EchoWorker::default()); + let queue_builder = with_fn(bee_builder); + let _hive = queue_builder.build(); + } + + #[rstest] + fn test_worker_default( + #[values(OpenBuilder::empty, OpenBuilder::default)] factory: F, + #[values(BeeBuilder::with_channel_queues, BeeBuilder::with_workstealing_queues)] with_fn: W, + ) where + F: Fn() -> OpenBuilder, + T: TaskQueues>, + W: Fn( + BeeBuilder>>, + ) -> FullBuilder>, T>, + { + let open_builder = factory(); + let bee_builder = open_builder.with_worker_default::>(); + let queue_builder = with_fn(bee_builder); + let _hive = queue_builder.build(); + } +} diff --git a/src/hive/context.rs b/src/hive/context.rs index 36f58f5..d8854ca 100644 --- a/src/hive/context.rs +++ b/src/hive/context.rs @@ -51,6 +51,11 @@ where self.worker_queues.push(task); task_id } + + #[cfg(test)] + fn thread_index(&self) -> usize { + self.worker_queues.thread_index() + } } impl fmt::Debug for HiveLocalContext<'_, W, Q, T> diff --git a/src/hive/cores.rs b/src/hive/cores.rs index cca74d1..769f3dd 100644 --- a/src/hive/cores.rs +++ b/src/hive/cores.rs @@ -1,6 +1,6 @@ //! Utilities for pinning worker threads to CPU cores in a `Hive`. use core_affinity::{self, CoreId}; -use parking_lot::Mutex; +use parking_lot::{Mutex, MutexGuard}; use std::collections::HashSet; use std::ops::{BitOr, BitOrAssign, Sub, SubAssign}; use std::sync::LazyLock; @@ -17,62 +17,55 @@ use std::sync::LazyLock; /// If new cores become available during the life of the program, they are immediately available /// for worker thread scheduling, but they are *not* available for pinning until the /// `refresh()` function is called. -static CORES: LazyLock>> = LazyLock::new(|| { - core_affinity::get_core_ids() - .map(|core_ids| core_ids.into_iter().map(Core::new).collect()) - .or_else(|| Some(Vec::new())) - .map(Mutex::new) - .unwrap() -}); - -/// Updates `CORES` with the currently available CPU core IDs. The correspondence between the -/// index in the sequence and the core ID is maintained for any core IDs already in the sequence. -/// If a previously available core has become unavailable, its `available` flag is set to `false`. -/// Any new cores are appended to the end of the sequence. Returns the number of new cores added to -/// the sequence. -pub fn refresh() -> usize { - let mut cur_ids = CORES.lock(); - let mut new_ids: HashSet<_> = core_affinity::get_core_ids() - .map(|core_ids| core_ids.into_iter().collect()) - .unwrap_or_default(); - cur_ids.iter_mut().for_each(|core| { - if new_ids.contains(&core.id) { - core.available = true; - new_ids.remove(&core.id); - } else { - core.available = false; - } - }); - let num_new_ids = new_ids.len(); - cur_ids.extend(new_ids.into_iter().map(Core::new)); - num_new_ids -} +pub static CORES: LazyLock = LazyLock::new(|| CoreIds::from_system()); -/// Represents a CPU core. 
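// A condensed sketch of the global-registry pattern this cores.rs refactor moves to:
// a `LazyLock` static that snapshots the available CPUs once via
// `core_affinity::get_core_ids`, guarded by a mutex so it can later be refreshed in
// place. It uses `std::sync::Mutex` instead of `parking_lot`, and plain `CoreId`s
// instead of the `Core` wrapper, purely for self-containment.
use core_affinity::CoreId;
use std::sync::{LazyLock, Mutex};

static CORES: LazyLock<Mutex<Vec<CoreId>>> = LazyLock::new(|| {
    // `get_core_ids` returns `None` on unsupported platforms; fall back to empty
    Mutex::new(core_affinity::get_core_ids().unwrap_or_default())
});

fn main() {
    let cores = CORES.lock().unwrap();
    println!("{} cores visible at startup", cores.len());
    // pin the current thread to the first core, if there is one
    if let Some(&first) = cores.first() {
        let pinned = core_affinity::set_for_current(first);
        println!("pinned: {pinned}");
    }
}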
-#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
-pub struct Core {
-    /// the OS-specific core ID
-    id: CoreId,
-    /// whether this core is currently available for pinning threads
-    available: bool,
-}
+/// Global list of CPU core IDs.
+///
+/// This is meant to be created at most once, when `CORES` is initialized.
+pub struct CoreIds(Mutex>);
 
-impl Core {
-    /// Creates a new `Core` with `available` set to `true`.
-    fn new(core_id: CoreId) -> Self {
-        Self {
-            id: core_id,
-            available: true,
-        }
+impl CoreIds {
+    fn from_system() -> Self {
+        Self::new(
+            core_affinity::get_core_ids()
+                .map(|core_ids| core_ids.into_iter().map(Core::from).collect())
+                .unwrap_or_default(),
+        )
     }
 
-    /// Attempts to pin the current thread to this CPU core. Returns `true` if the thread was
-    /// successfully pinned.
-    ///
-    /// If the `available` flag is `false`, this immediately returns `false` and does not attempt
-    /// to pin the thread.
-    pub fn try_pin_current(&self) -> bool {
-        self.available && core_affinity::set_for_current(self.id)
+    fn new(core_ids: Vec) -> Self {
+        Self(Mutex::new(core_ids))
+    }
+
+    fn get(&self, index: usize) -> Option {
+        self.0.lock().get(index).cloned()
+    }
+
+    fn update_from(&self, mut new_ids: HashSet) -> usize {
+        let mut cur_ids = self.0.lock();
+        cur_ids.iter_mut().for_each(|core| {
+            if new_ids.contains(&core.id) {
+                core.available = true;
+                new_ids.remove(&core.id);
+            } else {
+                core.available = false;
+            }
+        });
+        let num_new_ids = new_ids.len();
+        cur_ids.extend(new_ids.into_iter().map(Core::from));
+        num_new_ids
+    }
+
+    /// Updates `CORES` with the currently available CPU core IDs. The correspondence between the
+    /// index in the sequence and the core ID is maintained for any core IDs already in the
+    /// sequence. If a previously available core has become unavailable, its `available` flag is
+    /// set to `false`. Any new cores are appended to the end of the sequence. Returns the number
+    /// of new cores added to the sequence.
+    pub fn refresh(&self) -> usize {
+        let new_ids: HashSet<_> = core_affinity::get_core_ids()
+            .map(|core_ids| core_ids.into_iter().collect())
+            .unwrap_or_default();
+        self.update_from(new_ids)
     }
 }
 
@@ -119,10 +112,9 @@ impl Cores {
     /// Returns the `Core` associated with the specified index if the index exists and the core
     /// is available, otherwise returns `None`.
     pub fn get(&self, index: usize) -> Option {
-        let cores = CORES.lock();
         self.0
             .get(index)
-            .and_then(|&index| cores.get(index).cloned())
+            .and_then(|&index| CORES.get(index))
             .filter(|core| core.available)
     }
 
@@ -130,13 +122,7 @@
     /// set the core affinity of the current thread. The `core` will be `None` for cores that are
     /// not currently available.
     pub fn iter(&self) -> impl Iterator)> {
-        let cores = CORES.lock();
-        self.0.iter().cloned().map(move |index| {
-            (
-                index,
-                cores.get(index).filter(|core| core.available).cloned(),
-            )
-        })
+        CoreIter::new(self.0.iter().cloned())
    }
 }
 
@@ -184,10 +170,105 @@ impl> From for Cores {
     }
 }
 
+/// Iterator over core (index, id) tuples. This iterator holds the `MutexGuard` for the shared
+/// global `CoreIds`, so only one thread can iterate at a time.
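// A small sketch of the guard-holding iterator pattern `CoreIter` uses: the iterator
// owns the `MutexGuard` for the shared list, so the lock is held for exactly as long
// as iteration runs. `std::sync::Mutex` stands in for `parking_lot` here, and the
// element type is simplified.
use std::sync::{LazyLock, Mutex, MutexGuard};

static DATA: LazyLock<Mutex<Vec<u32>>> = LazyLock::new(|| Mutex::new(vec![10, 20, 30]));

struct GuardedIter<I: Iterator<Item = usize>> {
    indices: I,
    // released when the iterator is dropped
    guard: MutexGuard<'static, Vec<u32>>,
}

impl<I: Iterator<Item = usize>> Iterator for GuardedIter<I> {
    type Item = (usize, Option<u32>);

    fn next(&mut self) -> Option<Self::Item> {
        let index = self.indices.next()?;
        Some((index, self.guard.get(index).copied()))
    }
}

fn main() {
    let iter = GuardedIter {
        indices: [0usize, 2, 9].into_iter(),
        guard: DATA.lock().unwrap(),
    };
    let items: Vec<_> = iter.collect();
    assert_eq!(items, vec![(0, Some(10)), (2, Some(30)), (9, None)]);
}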
+pub struct CoreIter<'a, I: Iterator> { + index_iter: I, + cores: MutexGuard<'a, Vec>, +} + +impl<'a, I: Iterator> CoreIter<'a, I> { + fn new(index_iter: I) -> Self { + Self { + index_iter, + cores: CORES.0.lock(), + } + } +} + +impl<'a, I: Iterator> Iterator for CoreIter<'a, I> { + type Item = (usize, Option); + + fn next(&mut self) -> Option { + let index = self.index_iter.next()?; + let core = self.cores.get(index).cloned().filter(|core| core.available); + Some((index, core)) + } +} + +/// Represents a CPU core. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Core { + /// the OS-specific core ID + id: CoreId, + /// whether this core is currently available for pinning threads + available: bool, +} + +impl Core { + fn new(id: CoreId, available: bool) -> Self { + Self { id, available } + } + + /// Attempts to pin the current thread to this CPU core. Returns `true` if the thread was + /// successfully pinned. + /// + /// If the `available` flag is `false`, this immediately returns `false` and does not attempt + /// to pin the thread. + pub fn try_pin_current(&self) -> bool { + self.available && core_affinity::set_for_current(self.id) + } +} + +impl From for Core { + /// Creates a new `Core` with `available` set to `true`. + fn from(id: CoreId) -> Self { + Self::new(id, true) + } +} + #[cfg(test)] #[cfg_attr(coverage_nightly, coverage(off))] mod tests { use super::*; + use std::collections::HashSet; + + #[test] + fn test_core_ids() { + let core_ids = CoreIds::new((0..10usize).map(|id| Core::from(CoreId { id })).collect()); + assert_eq!( + (0..10) + .flat_map(|i| core_ids.get(i).map(|id| id.id)) + .collect::>(), + (0..10).map(|id| CoreId { id }).collect::>() + ); + assert!( + (0..10) + .map(|i| core_ids.get(i).map(|id| id.available).unwrap_or_default()) + .all(std::convert::identity) + ); + let new_ids: HashSet = vec![10, 11, 1, 3, 5, 7, 9] + .into_iter() + .map(|id| CoreId { id }) + .collect(); + let num_added = core_ids.update_from(new_ids); + assert_eq!(num_added, 2); + let mut new_core_ids = (0..12) + .flat_map(|i| core_ids.get(i).map(|id| id.id)) + .collect::>(); + new_core_ids.sort(); + assert_eq!( + new_core_ids, + (0..12).map(|id| CoreId { id }).collect::>() + ); + assert_eq!( + (0..12) + .flat_map(|i| core_ids.get(i)) + .filter(|id| id.available) + .count(), + 7 + ); + } #[test] fn test_empty() { diff --git a/src/hive/hive.rs b/src/hive/hive.rs index b28e129..0d96dac 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -77,6 +77,9 @@ impl, T: TaskQueues> Hive { &worker_queues, ) { Ok(_) => return, + // currently, the only implementation of retry queue cannot be put into + // a state where `try_send_retry` fails, so this cannot be tested + #[cfg_attr(coverage_nightly, coverage(off))] Err(task) => { let (input, task_meta, _) = task.into_parts(); Outcome::from_fatal(input, task_meta, error) @@ -95,11 +98,6 @@ impl, T: TaskQueues> Hive { }) } - #[inline] - fn shared(&self) -> &Arc> { - self.0.as_ref().unwrap() - } - /// Attempts to increase the number of worker threads by `num_threads`. Returns the number of /// new worker threads that were successfully started (which may be fewer than `num_threads`), /// or a `Poisoned` error if the hive has been poisoned. @@ -550,6 +548,24 @@ impl, T: TaskQueues> Hive { self.shared().set_suspended(false); } + /// Re-submits any unprocessed tasks for processing, with their results to be sent to `tx`. + /// + /// Returns a [`Vec`] of task IDs that were submitted. 
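// A hypothetical usage sketch for the new `swarm_unprocessed_*` methods, composed only
// from APIs that appear elsewhere in this patch (`suspend`, `resume`, `apply_store`,
// `join`). Which tasks actually end up `Unprocessed` depends on timing, so treat this
// as illustrative rather than guaranteed behavior.
use crate::bee::stock::{Thunk, ThunkWorker};
use crate::hive::{Builder, TaskQueuesBuilder, channel_builder};
use std::thread;
use std::time::Duration;

fn resubmit_after_suspend() {
    let hive = channel_builder(false)
        .num_threads(4)
        .with_worker_default::<ThunkWorker<()>>()
        .build();
    for _ in 0..10 {
        hive.apply_store(Thunk::from(|| thread::sleep(Duration::from_secs(1))));
    }
    // while suspended, queued tasks may be parked as `Unprocessed` outcomes
    hive.suspend();
    hive.join();
    // resume, then hand every `Unprocessed` input back to the hive for storage
    hive.resume();
    let resubmitted = hive.swarm_unprocessed_store();
    println!("resubmitted {} tasks", resubmitted.len());
}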
+    pub fn swarm_unprocessed_send>>(
+        &self,
+        outcome_tx: X,
+    ) -> Vec {
+        self.swarm_send(self.take_unprocessed_inputs(), outcome_tx)
+    }
+
+    /// Re-submits any unprocessed tasks for processing, with their results to be stored in the
+    /// hive.
+    ///
+    /// Returns a [`Vec`] of task IDs that were resumed.
+    pub fn swarm_unprocessed_store(&self) -> Vec {
+        self.swarm_store(self.take_unprocessed_inputs())
+    }
+
     /// Removes all `Unprocessed` outcomes from this `Hive` and returns them as an iterator over
     /// the input values.
     fn take_unprocessed_inputs(&self) -> impl ExactSizeIterator {
@@ -562,56 +578,6 @@ impl, T: TaskQueues> Hive {
         })
     }
 
-    /// If this `Hive` is suspended, resumes this `Hive` and re-submits any unprocessed tasks for
-    /// processing, with their results to be sent to `tx`. Returns a [`Vec`] of task IDs that
-    /// were resumed.
-    pub fn resume_send(&self, outcome_tx: &OutcomeSender) -> Vec {
-        self.shared()
-            .set_suspended(false)
-            .then(|| self.swarm_send(self.take_unprocessed_inputs(), outcome_tx))
-            .unwrap_or_default()
-    }
-
-    /// If this `Hive` is suspended, resumes this `Hive` and re-submit any unprocessed tasks for
-    /// processing, with their results to be stored in the queue. Returns a [`Vec`] of task IDs
-    /// that were resumed.
-    pub fn resume_store(&self) -> Vec {
-        self.shared()
-            .set_suspended(false)
-            .then(|| self.swarm_store(self.take_unprocessed_inputs()))
-            .unwrap_or_default()
-    }
-
-    /// Returns all stored outcomes as a [`HashMap`] of task IDs to `Outcome`s.
-    pub fn take_stored(&self) -> HashMap> {
-        self.shared().take_outcomes()
-    }
-
-    /// Consumes this `Hive` and attempts to acquire the shared data object.
-    ///
-    /// This closes the task queues so that no more tasks may be submitted. If `urgent` is `true`,
-    /// worker threads are also prevented from taking any more tasks from the queues; otherwise,
-    /// this method blocks while all queued are processed.
-    ///
-    /// If this `Hive` has been cloned, and those clones have not been dropped, this method returns
-    /// `None`.
-    fn try_close(mut self, urgent: bool) -> Option> {
-        if self.shared().num_referrers() > 1 {
-            return None;
-        }
-        // take the inner value and replace it with `None`
-        let shared = self.0.take().unwrap();
-        // close the global queue to prevent new tasks from being submitted
-        shared.close_task_queues(urgent);
-        // wait for all tasks to finish
-        shared.wait_on_done();
-        // unwrap the Arc and return the inner Shared value
-        Some(
-            super::util::unwrap_arc(shared)
-                .expect("timeout waiting to take ownership of shared data"),
-        )
-    }
-
     /// Consumes this `Hive` and attempts to shut it down gracefully.
     ///
     /// If this `Hive` has been cloned, and those clones have not been dropped, this method returns
@@ -659,6 +625,36 @@ impl, T: TaskQueues> Hive {
     pub fn try_into_husk(self, urgent: bool) -> Option> {
         self.try_close(urgent).map(|shared| shared.into_husk())
     }
+
+    /// Consumes this `Hive` and attempts to acquire the shared data object.
+    ///
+    /// This closes the task queues so that no more tasks may be submitted. If `urgent` is `true`,
+    /// worker threads are also prevented from taking any more tasks from the queues; otherwise,
+    /// this method blocks while all queued tasks are processed.
+    ///
+    /// If this `Hive` has been cloned, and those clones have not been dropped, this method returns
+    /// `None`.
+ fn try_close(mut self, urgent: bool) -> Option> { + if self.shared().num_referrers() > 1 { + return None; + } + // take the inner value and replace it with `None` + let shared = self.0.take().unwrap(); + // close the global queue to prevent new tasks from being submitted + shared.close_task_queues(urgent); + // wait for all tasks to finish + shared.wait_on_done(); + // unwrap the Arc and return the inner Shared value + Some( + super::util::unwrap_arc(shared) + .expect("timeout waiting to take ownership of shared data"), + ) + } + + #[inline] + fn shared(&self) -> &Arc> { + self.0.as_ref().unwrap() + } } pub type DefaultHive = Hive, ChannelTaskQueues>; @@ -871,7 +867,7 @@ mod tests { use std::time::Duration; #[test] - fn test_suspend() { + fn test_suspend_resume() { let hive = channel_builder(false) .num_threads(4) .with_worker_default::>() @@ -885,6 +881,7 @@ mod tests { thread::sleep(Duration::from_secs(1)); // There should be 4 active tasks and 6 queued tasks. hive.suspend(); + assert!(hive.is_suspended()); assert_eq!(hive.num_tasks(), (6, 4)); // Wait for active tasks to complete. hive.join(); @@ -907,6 +904,7 @@ mod tests { assert_eq!(hive.alive_workers(), 4); // poison hive using private method hive.0.as_ref().unwrap().poison(); + assert!(hive.is_poisoned()); // attempt to spawn a new task assert!(matches!(hive.grow(1), Err(Poisoned))); // make sure the worker count wasn't increased diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index d72a853..fece56b 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -168,6 +168,11 @@ impl WorkerQueues for ChannelWorkerQueues { fn try_push_retry(&self, task: Task) -> Result> { self.shared.try_push_retry(task) } + + #[cfg(test)] + fn thread_index(&self) -> usize { + self.shared._thread_index + } } /// Worker thread-specific data shared with the main thread. diff --git a/src/hive/inner/queue/mod.rs b/src/hive/inner/queue/mod.rs index 8980b5f..399daeb 100644 --- a/src/hive/inner/queue/mod.rs +++ b/src/hive/inner/queue/mod.rs @@ -87,4 +87,8 @@ pub trait WorkerQueues { /// to the retry queue (e.g., if the queue is full), the task returned as an error. #[cfg(feature = "retry")] fn try_push_retry(&self, task: Task) -> Result>; + + /// Returns the unique index of the thread that owns this `WorkerQueues` instance. + #[cfg(test)] + fn thread_index(&self) -> usize; } diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 4557780..cbf9331 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -112,6 +112,9 @@ impl GlobalQueue { } /// Tries to steal a task from a random worker using its `Stealer`. + /// + /// If no tasks are available, sleeps for `EMPTY_DELAY` and returns `PopTaskError::Empty`. + /// Returns `PopTaskError::Closed` if the queue is closed. 
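+    ///
+    /// Conceptually, the victim selection below behaves like this sketch (where
+    /// `random_index` is a stand-in for the thread-local RNG call, not a real helper):
+    ///
+    /// ```ignore
+    /// for _ in 0..stealers.len() {
+    ///     let victim = random_index(stealers.len());
+    ///     if let Some(task) = stealers[victim].steal().success() {
+    ///         return Ok(task);
+    ///     }
+    /// }
+    /// ```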
fn try_steal_from_worker(&self) -> Result, PopTaskError> { let stealers = self.stealers.read(); let n = stealers.len(); @@ -124,6 +127,7 @@ impl GlobalQueue { if self.is_closed() && self.queue.is_empty() { PopTaskError::Closed } else { + // TODO: instead try Backoff-based snoozing used by crossbeam thread::park_timeout(EMPTY_DELAY); PopTaskError::Empty } @@ -241,6 +245,11 @@ impl WorkerQueues for WorkstealingWorkerQueues { fn try_push_retry(&self, task: Task) -> Result> { self.shared.try_push_retry(task) } + + #[cfg(test)] + fn thread_index(&self) -> usize { + self.shared._thread_index + } } impl Deref for WorkstealingWorkerQueues { diff --git a/src/hive/inner/shared.rs b/src/hive/inner/shared.rs index cf192d7..dc9728d 100644 --- a/src/hive/inner/shared.rs +++ b/src/hive/inner/shared.rs @@ -490,11 +490,6 @@ impl, T: TaskQueues> Shared { self.outcomes.push(outcome); } - /// Removes and returns all retained task outcomes. - pub fn take_outcomes(&self) -> HashMap> { - self.outcomes.drain() - } - /// Removes and returns all retained `Unprocessed` outcomes. pub fn take_unprocessed(&self) -> Vec> { let mut outcomes = self.outcomes.get_mut(); @@ -711,7 +706,7 @@ mod retry { /// Returns the current worker retry factor. pub fn worker_retry_factor(&self) -> std::time::Duration { - std::time::Duration::from_millis(self.config.retry_factor.get().unwrap_or_default()) + std::time::Duration::from_nanos(self.config.retry_factor.get().unwrap_or_default()) } /// Returns `true` if the hive is configured to retry tasks and the `attempt` field of the @@ -754,7 +749,7 @@ mod tests { super::Shared, ChannelTaskQueues>; #[test] - fn test_sync_hared() { + fn test_sync_shared() { fn assert_sync() {} assert_sync::(); } diff --git a/src/hive/mock.rs b/src/hive/mock.rs index 6a1e5fd..0db2cd6 100644 --- a/src/hive/mock.rs +++ b/src/hive/mock.rs @@ -70,6 +70,11 @@ where fn submit_task(&self, _: I) -> TaskId { self.0.next_task_id() } + + #[cfg(test)] + fn thread_index(&self) -> usize { + 0 + } } #[cfg(test)] diff --git a/src/hive/mod.rs b/src/hive/mod.rs index c9cd590..7084f44 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -530,7 +530,10 @@ mod tests { thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().0, 1); assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); - hive.grow(1).expect("error spawning threads"); + assert!(matches!(hive.grow(0), Ok(0))); + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks().0, 1); + assert!(matches!(hive.grow(1), Ok(1))); thread::sleep(ONE_SEC); assert_eq!(hive.num_tasks().0, 0); assert!(matches!( @@ -567,6 +570,27 @@ mod tests { assert_eq!(husk.iter_successes().count(), total_threads); } + #[rstest] + fn test_use_all_cores(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = void_thunk_hive(0, builder_factory(false)); + let num_cores = num_cpus::get(); + // queue some long-running tasks + for _ in 0..num_cores { + hive.apply_store(Thunk::from(|| thread::sleep(LONG_TASK))); + } + thread::sleep(ONE_SEC); + assert_eq!(hive.num_tasks().0, num_cores as u64); + assert_eq!(hive.use_all_cores().unwrap(), num_cores); + assert_eq!(hive.max_workers(), num_cores); + thread::sleep(ONE_SEC); + let husk = hive.try_into_husk(false).unwrap(); + assert_eq!(husk.iter_successes().count(), num_cores); + } + #[rstest] fn test_suspend(#[values(channel_builder, workstealing_builder)] builder_factory: F) where @@ -620,7 +644,43 @@ mod tests { } #[rstest] - fn test_suspend_with_cancelled_tasks( + fn 
test_suspend_resume_send_with_cancelled_tasks( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive: Hive<_, _> = builder_factory(false) + .num_threads(TEST_TASKS) + .with_worker_default::() + .build(); + let _ = hive.swarm_store(0..TEST_TASKS as u8); + // wait for tasks to be started + thread::sleep(Duration::from_millis(500)); + assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64)); + hive.suspend(); + // wait for tasks to be cancelled + thread::sleep(Duration::from_secs(2)); + assert_eq!(hive.num_tasks(), (0, 0)); + assert_eq!(hive.num_unprocessed(), TEST_TASKS); + hive.resume(); + let (tx, rx) = super::outcome_channel(); + let new_task_ids = hive.swarm_unprocessed_send(tx); + assert_eq!(new_task_ids.len(), TEST_TASKS); + thread::sleep(Duration::from_millis(500)); + // unprocessed tasks should be requeued + assert_eq!(hive.num_tasks(), (0, TEST_TASKS as u64)); + hive.join(); + let mut outputs = rx + .into_iter() + .select_ordered_outputs(new_task_ids) + .collect::>(); + outputs.sort(); + assert_eq!(outputs, (0..TEST_TASKS as u8).collect::>()); + } + + #[rstest] + fn test_suspend_resume_store_with_cancelled_tasks( #[values(channel_builder, workstealing_builder)] builder_factory: F, ) where B: TaskQueuesBuilder, @@ -634,7 +694,8 @@ mod tests { hive.suspend(); // wait for tasks to be cancelled thread::sleep(Duration::from_secs(2)); - hive.resume_store(); + hive.resume(); + hive.swarm_unprocessed_store(); thread::sleep(Duration::from_secs(1)); // unprocessed tasks should be requeued assert_eq!(hive.num_tasks().1, TEST_TASKS as u64); @@ -1310,6 +1371,65 @@ mod tests { assert_eq!(task_ids, outcome_task_ids); } + #[rstest] + fn test_try_scan(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) + .with_worker(Caller::from(|i: i32| i * i)) + .num_threads(4) + .build(); + let (outcomes, error, state) = hive.try_scan(0..10, 0, |acc, i| { + *acc += i; + Ok::<_, String>(*acc) + }); + let task_ids: Vec<_> = outcomes.success_task_ids(); + assert_eq!(task_ids.len(), 10); + assert_eq!(error.len(), 0); + assert_eq!(state, 45); + let mut values: Vec<_> = outcomes + .into_iter() + .select_unordered(task_ids) + .into_outputs() + .collect(); + values.sort(); + assert_eq!( + values, + (0..10) + .scan(0, |acc, i| { + *acc += i; + Some(*acc) + }) + .map(|i| i * i) + .collect::>() + ); + } + + #[rstest] + #[should_panic] + fn test_try_scan_fail(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) + .with_worker(Caller::from(|i: i32| i * i)) + .num_threads(4) + .build(); + let (outcomes, error, state) = hive.try_scan(0..10, 0, |_, _| Err::("fail")); + let task_ids: Vec<_> = outcomes.success_task_ids(); + assert_eq!(task_ids.len(), 10); + assert_eq!(error.len(), 0); + assert_eq!(state, 45); + let _ = outcomes + .into_iter() + .select_unordered(task_ids) + .into_outputs() + .collect::>(); + } + #[rstest] fn test_try_scan_send(#[values(channel_builder, workstealing_builder)] builder_factory: F) where @@ -1514,6 +1634,36 @@ mod tests { assert_eq!(outputs, (0..NUM_FIRST_TASKS * 2).collect::>()); } + #[rstest] + fn test_close(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive1 = thunk_hive::(8, builder_factory(false)); + let _ = 
hive1.map_store((0..8u8).map(|i| Thunk::from(move || i))); + hive1.join(); + let hive2 = hive1.clone(); + assert!(!hive1.close(false)); + assert!(hive2.close(false)); + } + + #[rstest] + fn test_into_outcomes(#[values(channel_builder, workstealing_builder)] builder_factory: F) + where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = thunk_hive::(8, builder_factory(false)); + let task_ids = hive.map_store((0..8u8).map(|i| Thunk::from(move || i))); + hive.join(); + let outcomes = hive.try_into_outcomes(false).unwrap(); + for i in task_ids.iter() { + assert!(outcomes.get(i).unwrap().is_success()); + assert!(matches!(outcomes.get(i), Some(Outcome::Success { .. }))); + } + } + #[rstest] fn test_husk(#[values(channel_builder, workstealing_builder)] builder_factory: F) where @@ -1954,8 +2104,11 @@ mod tests { #[cfg(all(test, feature = "affinity"))] mod affinity_tests { use crate::bee::stock::{Thunk, ThunkWorker}; - use crate::hive::{Builder, TaskQueuesBuilder, channel_builder, workstealing_builder}; + use crate::channel::{Message, ReceiverExt}; + use crate::hive::{Builder, Outcome, TaskQueuesBuilder, channel_builder, workstealing_builder}; use rstest::*; + use std::thread; + use std::time::Duration; #[rstest] fn test_affinity(#[values(channel_builder, workstealing_builder)] builder_factory: F) @@ -1980,8 +2133,13 @@ mod affinity_tests { } #[rstest] - fn test_use_all_cores() { - let hive = crate::hive::channel_builder(false) + fn test_use_all_cores_builder( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) .thread_name("affinity example") .with_thread_per_core() .with_default_core_affinity() @@ -1996,6 +2154,56 @@ mod affinity_tests { }) })); } + + #[rstest] + fn test_grow_with_affinity( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) + .thread_name("affinity example") + .with_default_core_affinity() + .with_worker_default::>() + .build(); + // check that with 0 threads no tasks are scheduled + let (tx, rx) = super::outcome_channel(); + let _ = hive.apply_send(Thunk::from(|| 0), &tx); + thread::sleep(Duration::from_secs(1)); + assert_eq!(hive.num_tasks().0, 1); + assert!(matches!(rx.try_recv_msg(), Message::ChannelEmpty)); + assert!(matches!(hive.grow_with_affinity(0, vec![]), Ok(0))); + thread::sleep(Duration::from_secs(1)); + assert_eq!(hive.num_tasks().0, 1); + assert!(matches!(hive.grow_with_affinity(1, vec![0]), Ok(1))); + thread::sleep(Duration::from_secs(1)); + assert_eq!(hive.num_tasks().0, 0); + assert!(matches!( + rx.try_recv_msg(), + Message::Received(Outcome::Success { value: 0, .. 
}) + )); + } + + #[rstest] + fn test_use_all_cores_hive() { + let hive = crate::hive::channel_builder(false) + .thread_name("affinity example") + .with_default_core_affinity() + .with_worker_default::>() + .build(); + + let num_cores = num_cpus::get(); + assert_eq!(hive.use_all_cores_with_affinity().unwrap(), num_cores); + + hive.map_store((0..num_cpus::get()).map(move |i| { + Thunk::from(move || { + if let Some(affininty) = core_affinity::get_core_ids() { + eprintln!("task {} on thread with affinity {:?}", i, affininty); + } + }) + })); + } } #[cfg(all(test, feature = "local-batch"))] @@ -2184,16 +2392,21 @@ mod local_batch_tests { assert!(thread_counts.values().all(|count| *count > BATCH_LIMIT_0)); assert_eq!(thread_counts.values().sum::(), total_tasks); } + + #[test] + fn test_change_channel_batch_limit_nonempty() {} } #[cfg(all(test, feature = "local-batch"))] mod weighted_map_tests { - use crate::bee::stock::{Thunk, ThunkWorker}; + use crate::bee::stock::{RetryCaller, Thunk, ThunkWorker}; + use crate::bee::{ApplyError, Context}; use crate::hive::{ Builder, Outcome, TaskQueuesBuilder, Weighted, WeightedIteratorExt, channel_builder, workstealing_builder, }; use rstest::*; + use std::collections::HashMap; use std::thread; use std::time::Duration; @@ -2223,6 +2436,54 @@ mod weighted_map_tests { assert_eq!(outputs, (0..10).collect::>()) } + #[rstest] + fn test_map_weighted_with_limit( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + const NUM_THREADS: usize = 4; + const NUM_TASKS_PER_THREAD: usize = 3; + const NUM_TASKS: usize = NUM_THREADS * NUM_TASKS_PER_THREAD; + const BATCH_LIMIT: usize = 10; + const WEIGHT: u32 = 25; + const WEIGHT_LIMIT: u64 = WEIGHT as u64 * NUM_TASKS_PER_THREAD as u64; + // schedule 2 * NUM_THREADS tasks, and set the weight limit at 2 * task weight, such that, + // even though the batch size is > 2, each thread should only take 2 tasks + let hive = builder_factory(false) + .with_worker(RetryCaller::from( + |i: u8, ctx: &Context| -> Result<(u8, Option), ApplyError> { + thread::sleep(Duration::from_millis(100)); + Ok((i, ctx.thread_index())) + }, + )) + .num_threads(NUM_THREADS) + .batch_limit(BATCH_LIMIT) + .weight_limit(WEIGHT_LIMIT) + .build(); + let inputs = (0..NUM_TASKS as u8).map(|i| (i, WEIGHT)).into_weighted(); + let (mut outputs, thread_indices) = hive + .map(inputs) + .map(Outcome::unwrap) + .unzip::<_, _, Vec<_>, Vec<_>>(); + outputs.sort(); + assert_eq!(outputs, (0..NUM_TASKS as u8).collect::>()); + let counts = + thread_indices + .into_iter() + .flatten() + .fold(HashMap::new(), |mut counts, index| { + counts + .entry(index) + .and_modify(|count| *count += 1) + .or_insert(1); + counts + }); + println!("{:?}", counts); + assert!(counts.values().all(|&count| count == NUM_TASKS_PER_THREAD)); + } + #[rstest] fn test_overweight() { const WEIGHT_LIMIT: u64 = 99; @@ -2237,6 +2498,26 @@ mod weighted_map_tests { Outcome::WeightLimitExceeded { weight: 100, .. } )) } + + #[rstest] + fn test_set_weight_limit() { + const WEIGHT_LIMIT: u64 = 99; + let hive = channel_builder(false) + .with_worker_default::>() + .num_threads(1) + .weight_limit(WEIGHT_LIMIT) + .build(); + assert_eq!(WEIGHT_LIMIT, hive.worker_weight_limit()); + let outcome = hive.apply(Weighted::new(Thunk::from(|| 0), WEIGHT_LIMIT + 1)); + assert!(matches!( + outcome, + Outcome::WeightLimitExceeded { weight: 100, .. 
} + )); + hive.set_worker_weight_limit(WEIGHT_LIMIT + 1); + assert_eq!(WEIGHT_LIMIT + 1, hive.worker_weight_limit()); + let outcome = hive.apply(Weighted::new(Thunk::from(|| 0), WEIGHT_LIMIT + 1)); + assert!(matches!(outcome, Outcome::Success { .. })); + } } #[cfg(all(test, feature = "local-batch"))] @@ -2451,4 +2732,33 @@ mod retry_tests { let v: Result, _> = hive.swarm(0..10usize).into_results().collect(); assert!(v.is_err()); } + + #[rstest] + fn test_change_retry_limit( + #[values(channel_builder, workstealing_builder)] builder_factory: F, + ) where + B: TaskQueuesBuilder, + F: Fn(bool) -> B, + { + let hive = builder_factory(false) + .with_worker(RetryCaller::from(echo_time)) + .with_thread_per_core() + .with_no_retries() + .build(); + + assert_eq!(hive.worker_retry_limit(), 0); + assert_eq!(hive.worker_retry_factor(), Duration::from_secs(0)); + + let v: Result, _> = hive.swarm(0..10usize).into_results().collect(); + assert!(v.is_err()); + + hive.set_worker_retry_limit(3); + hive.set_worker_retry_factor(Duration::from_secs(1)); + + assert_eq!(hive.worker_retry_limit(), 3); + assert_eq!(hive.worker_retry_factor(), Duration::from_secs(1)); + + let v: Result, _> = hive.swarm(0..10usize).into_results().collect(); + assert_eq!(v.unwrap().len(), 10); + } } diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs index da66902..e2b11d7 100644 --- a/src/hive/outcome/queue.rs +++ b/src/hive/outcome/queue.rs @@ -26,14 +26,6 @@ impl OutcomeQueue { outcomes } - /// Flushes the queue into the map of outcomes, then takes all outcomes from the map and - /// returns them. - pub fn drain(&self) -> HashMap> { - let mut outcomes: HashMap> = self.outcomes.lock().drain().collect(); - drain_into(&self.queue, &mut outcomes); - outcomes - } - /// Consumes this `OutcomeQueue`, drains the queue, and returns the outcomes as a map. 
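+    ///
+    /// For example (a sketch; construction via `Default` is assumed, as in the tests
+    /// below):
+    ///
+    /// ```ignore
+    /// let queue = OutcomeQueue::default();
+    /// queue.push(Outcome::Missing { task_id: 1 });
+    /// let outcomes = queue.into_inner();
+    /// assert!(matches!(outcomes[&1], Outcome::Missing { .. }));
+    /// ```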
pub fn into_inner(self) -> HashMap> { let mut outcomes = self.outcomes.into_inner(); @@ -93,7 +85,7 @@ mod tests { }); assert_eq!(queue.count(), (1, 1, 1)); queue.push(Outcome::Missing { task_id: 4 }); - let outcomes = queue.drain(); + let outcomes = queue.into_inner(); assert_eq!(outcomes.len(), 4); assert_eq!( outcomes[&1], From d4d68d481a6d876b1ec55f8baf2f05f4c4cdf4ea Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 11:15:03 -0700 Subject: [PATCH 48/67] replace rand w nanorand --- Cargo.toml | 7 ++++--- src/hive/inner/queue/workstealing.rs | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 816458b..9f428f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,13 +14,14 @@ crossbeam-deque = "0.8.6" crossbeam-queue = "0.3.12" crossbeam-utils = "0.8.20" derive_more = { version = "2.0.1", features = ["debug"] } +nanorand = { version = "0.7.0", default-features = false, features = [ + "std", + "tls", +] } num = "0.4.3" num_cpus = "1.16.0" parking_lot = "0.12.3" paste = "1.0.15" -rand = { version = "0.9.0", default-features = false, features = [ - "thread_rng", -] } thiserror = "1.0.63" # required with the `affinity` feature core_affinity = { version = "0.8.1", optional = true } diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index cbf9331..2ac349a 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -11,8 +11,8 @@ use crate::bee::Worker; use crossbeam_deque::{Injector, Stealer}; use crossbeam_queue::SegQueue; use derive_more::Debug; +use nanorand::{Rng, tls as rand}; use parking_lot::RwLock; -use rand::prelude::*; use std::ops::Deref; use std::sync::Arc; use std::time::Duration; @@ -119,7 +119,7 @@ impl GlobalQueue { let stealers = self.stealers.read(); let n = stealers.len(); // randomize the stealing order, to prevent always stealing from the same thread - std::iter::from_fn(|| Some(rand::rng().random_range(0..n))) + std::iter::from_fn(|| Some(rand::tls_rng().generate_range(0..n))) .take(n) .filter_map(|i| stealers[i].steal().success()) .next() From 02257e0cce3510aad10a6ace6843e6079e720552 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 11:19:21 -0700 Subject: [PATCH 49/67] allow paste dep --- deny.toml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/deny.toml b/deny.toml index 2de519b..722fc45 100644 --- a/deny.toml +++ b/deny.toml @@ -88,7 +88,7 @@ ignore = [ # List of explicitly allowed licenses # See https://spdx.org/licenses/ for list of possible licenses # [possible values: any SPDX 3.11 short identifier (+ optional exception)]. -allow = ["MIT", "Apache-2.0", "Unicode-DFS-2016", "Unicode-3.0"] +allow = ["MIT", "Apache-2.0", "Unicode-3.0"] # The confidence threshold for detecting a license from license text. # The higher the value, the more closely the license text must be to the # canonical license text of a valid SPDX license file. @@ -156,10 +156,7 @@ workspace-default-features = "allow" # on a crate-by-crate basis if desired. external-default-features = "allow" # List of crates that are allowed. Use with care! 
-allow = [ - #"ansi_term@0.11.0", - #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" }, -] +allow = [{ crate = "paste@1.0.15", reason = "Unmaintained but 'finished'" }] # List of crates to deny deny = [ #"ansi_term@0.11.0", From 6965c0266cbb34b8db10fa204242c29981b2beb0 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 11:22:01 -0700 Subject: [PATCH 50/67] fix test --- src/hive/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 7084f44..4148b81 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -2454,7 +2454,7 @@ mod weighted_map_tests { let hive = builder_factory(false) .with_worker(RetryCaller::from( |i: u8, ctx: &Context| -> Result<(u8, Option), ApplyError> { - thread::sleep(Duration::from_millis(100)); + thread::sleep(Duration::from_millis(500)); Ok((i, ctx.thread_index())) }, )) From eafd0d2b80806859e54e0624ca0d83d82229f62f Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 11:32:35 -0700 Subject: [PATCH 51/67] allow paste dep --- deny.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deny.toml b/deny.toml index 722fc45..a03f0c0 100644 --- a/deny.toml +++ b/deny.toml @@ -70,10 +70,7 @@ feature-depth = 1 # A list of advisory IDs to ignore. Note that ignored advisories will still # output a note when they are encountered. ignore = [ - #"RUSTSEC-0000-0000", - #{ id = "RUSTSEC-0000-0000", reason = "you can specify a reason the advisory is ignored" }, - #"a-crate-that-is-yanked@0.1.1", # you can also ignore yanked crate versions if you wish - #{ crate = "a-crate-that-is-yanked@0.1.1", reason = "you can specify why you are ignoring the yanked crate" }, + { id = "RUSTSEC-2024-0436", reason = "paste is considered 'finished'" }, ] # If this is true, then cargo deny will use the git executable to fetch advisory database. # If this is false, then it uses a built-in git library. @@ -156,7 +153,10 @@ workspace-default-features = "allow" # on a crate-by-crate basis if desired. external-default-features = "allow" # List of crates that are allowed. Use with care! -allow = [{ crate = "paste@1.0.15", reason = "Unmaintained but 'finished'" }] +allow = [ + #"ansi_term@0.11.0", + #{ crate = "ansi_term@0.11.0", reason = "you can specify a reason it is allowed" }, +] # List of crates to deny deny = [ #"ansi_term@0.11.0", From fa89d06b4d2217ceaec5be0a4addb1052cef1632 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 12:58:09 -0700 Subject: [PATCH 52/67] fix lints --- src/hive/cores.rs | 12 ++++-------- src/hive/inner/queue/workstealing.rs | 15 +++++++++------ 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/hive/cores.rs b/src/hive/cores.rs index 769f3dd..249fe03 100644 --- a/src/hive/cores.rs +++ b/src/hive/cores.rs @@ -17,7 +17,7 @@ use std::sync::LazyLock; /// If new cores become available during the life of the program, they are immediately available /// for worker thread scheduling, but they are *not* available for pinning until the /// `refresh()` function is called. -pub static CORES: LazyLock = LazyLock::new(|| CoreIds::from_system()); +pub static CORES: LazyLock = LazyLock::new(CoreIds::from_system); /// Global list of CPU core IDs. 
/// @@ -177,7 +177,7 @@ pub struct CoreIter<'a, I: Iterator> { cores: MutexGuard<'a, Vec>, } -impl<'a, I: Iterator> CoreIter<'a, I> { +impl> CoreIter<'_, I> { fn new(index_iter: I) -> Self { Self { index_iter, @@ -186,7 +186,7 @@ impl<'a, I: Iterator> CoreIter<'a, I> { } } -impl<'a, I: Iterator> Iterator for CoreIter<'a, I> { +impl> Iterator for CoreIter<'_, I> { type Item = (usize, Option); fn next(&mut self) -> Option { @@ -242,11 +242,7 @@ mod tests { .collect::>(), (0..10).map(|id| CoreId { id }).collect::>() ); - assert!( - (0..10) - .map(|i| core_ids.get(i).map(|id| id.available).unwrap_or_default()) - .all(std::convert::identity) - ); + assert!((0..10).all(|i| core_ids.get(i).map(|id| id.available).unwrap_or_default())); let new_ids: HashSet = vec![10, 11, 1, 3, 5, 7, 9] .into_iter() .map(|id| CoreId { id }) diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 2ac349a..66224ca 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -113,9 +113,9 @@ impl GlobalQueue { /// Tries to steal a task from a random worker using its `Stealer`. /// - /// If no tasks are available, sleeps for `EMPTY_DELAY` and returns `PopTaskError::Empty`. - /// Returns `PopTaskError::Closed` if the queue is closed. - fn try_steal_from_worker(&self) -> Result, PopTaskError> { + /// Returns the task if one is stolen successfully, otherwise snoozes for a bit and then + /// returns `PopTaskError::Empty`. Returns `PopTaskError::Closed` if the queue is closed. + fn try_steal_from_worker_or_snooze(&self) -> Result, PopTaskError> { let stealers = self.stealers.read(); let n = stealers.len(); // randomize the stealing order, to prevent always stealing from the same thread @@ -127,7 +127,10 @@ impl GlobalQueue { if self.is_closed() && self.queue.is_empty() { PopTaskError::Closed } else { - // TODO: instead try Backoff-based snoozing used by crossbeam + // TODO: instead try the parking approach used in rust-executors, which seems + // more performant under most circumstances + // https://github.com/Bathtor/rust-executors/blob/master/executors/src/crossbeam_workstealing_pool.rs#L976 + thread::park_timeout(EMPTY_DELAY); PopTaskError::Empty } @@ -140,7 +143,7 @@ impl GlobalQueue { if let Some(task) = self.queue.steal().success() { Ok(task) } else { - self.try_steal_from_worker() + self.try_steal_from_worker_or_snooze() } } @@ -186,7 +189,7 @@ impl GlobalQueue { } return Ok(first); } - self.try_steal_from_worker() + self.try_steal_from_worker_or_snooze() } fn is_closed(&self) -> bool { From ad83948a8f697a5f070a5c5b066331a4c988505f Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 13:01:25 -0700 Subject: [PATCH 53/67] fix lints --- src/hive/inner/queue/workstealing.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 66224ca..08afc96 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -130,7 +130,6 @@ impl GlobalQueue { // TODO: instead try the parking approach used in rust-executors, which seems // more performant under most circumstances // https://github.com/Bathtor/rust-executors/blob/master/executors/src/crossbeam_workstealing_pool.rs#L976 - thread::park_timeout(EMPTY_DELAY); PopTaskError::Empty } From f71017697e4af133e2361e2bc98db86de34db129 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 13:11:18 -0700 Subject: [PATCH 54/67] fix doc tests --- src/hive/hive.rs | 2 +- 
src/hive/inner/builder.rs | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 0d96dac..182a702 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -525,7 +525,7 @@ impl, T: TaskQueues> Hive { /// .num_threads(4) /// .with_worker_default::>() /// .build(); - /// hive.map((0..10).map(|_| Thunk::of(|| thread::sleep(Duration::from_secs(3))))); + /// hive.map((0..10).map(|_| Thunk::from(|| thread::sleep(Duration::from_secs(3))))); /// thread::sleep(Duration::from_secs(1)); // Allow first set of tasks to be started. /// // There should be 4 active tasks and 6 queued tasks. /// hive.suspend(); diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 67703f2..7bfb00c 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -29,7 +29,7 @@ pub trait Builder: BuilderConfig + Sized { /// .build(); /// /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { + /// hive.apply_store(Thunk::from(|| { /// println!("Hello from a worker thread!") /// })); /// } @@ -66,7 +66,7 @@ pub trait Builder: BuilderConfig + Sized { /// .build(); /// /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { + /// hive.apply_store(Thunk::from(|| { /// println!("Hello from a worker thread!") /// })); /// } @@ -99,7 +99,7 @@ pub trait Builder: BuilderConfig + Sized { /// .build(); /// /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { + /// hive.apply_store(Thunk::from(|| { /// assert_eq!(thread::current().name(), Some("foo")); /// })); /// } @@ -132,7 +132,7 @@ pub trait Builder: BuilderConfig + Sized { /// .build(); /// /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { + /// hive.apply_store(Thunk::from(|| { /// println!("This thread has a 4 MB stack size!"); /// })); /// } @@ -171,7 +171,7 @@ pub trait Builder: BuilderConfig + Sized { /// .build(); /// /// for _ in 0..100 { - /// hive.apply_store(Thunk::of(|| { + /// hive.apply_store(Thunk::from(|| { /// println!("This thread is pinned!"); /// })); /// } @@ -291,7 +291,7 @@ pub trait Builder: BuilderConfig + Sized { /// # fn main() { /// let hive = channel_builder(true) /// .max_retries(3) - /// .with_worker(RetryCaller::of(sometimes_fail)) + /// .with_worker(RetryCaller::from(sometimes_fail)) /// .build(); /// /// for i in 0..10 { @@ -338,7 +338,7 @@ pub trait Builder: BuilderConfig + Sized { /// let hive = channel_builder(true) /// .max_retries(3) /// .retry_factor(time::Duration::from_secs(1)) - /// .with_worker(RetryCaller::of(echo_time)) + /// .with_worker(RetryCaller::from(echo_time)) /// .build(); /// /// for i in 0..10 { From 4b1717585c4be350c32d01f61a9b139148ca06db Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 13:15:00 -0700 Subject: [PATCH 55/67] fix doc tests --- src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 2512472..0238b9b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -161,14 +161,14 @@ //! // return results to your own channel... //! let (tx, rx) = outcome_channel(); //! let _ = hive.swarm_send( -//! (0..10).map(|i: i32| Thunk::of(move || i * i)), +//! (0..10).map(|i: i32| Thunk::from(move || i * i)), //! tx //! ); //! assert_eq!(285, rx.into_outputs().take(10).sum()); //! //! // return results as an iterator... //! let total = hive -//! .swarm_unordered((0..10).map(|i: i32| Thunk::of(move || i * -i))) +//! .swarm_unordered((0..10).map(|i: i32| Thunk::from(move || i * -i))) //! .into_outputs() //! .sum(); //! 
assert_eq!(-285, total); From b7207874c72bb307a91075bc6f88dfa6066402f8 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 13:19:31 -0700 Subject: [PATCH 56/67] fix --- Cargo.toml | 2 +- src/hive/mock.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9f428f0..cb06524 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ name = "perf" harness = false [features] -default = ["affinity", "local-batch", "retry"] +default = [] affinity = ["dep:core_affinity"] local-batch = [] retry = [] diff --git a/src/hive/mock.rs b/src/hive/mock.rs index 0db2cd6..010a97b 100644 --- a/src/hive/mock.rs +++ b/src/hive/mock.rs @@ -107,7 +107,7 @@ mod tests { #[test] fn test_works() { let runner = MockTaskRunner::::default(); - let outcome = runner.apply(42); + let outcome = runner.apply(42usize); assert!(matches!( outcome, Outcome::SuccessWithSubtasks { From ade5e5366a0897f2da423fafe22fbf1c36d64261 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 13:29:51 -0700 Subject: [PATCH 57/67] fix --- src/hive/inner/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hive/inner/mod.rs b/src/hive/inner/mod.rs index 504cf8d..a9b5946 100644 --- a/src/hive/inner/mod.rs +++ b/src/hive/inner/mod.rs @@ -38,7 +38,7 @@ type Any = AtomicOption>; type Usize = AtomicOption; #[cfg(feature = "retry")] type U8 = AtomicOption; -#[cfg(feature = "retry")] +#[cfg(any(feature = "local-batch", feature = "retry"))] type U64 = AtomicOption; /// Private, zero-size struct used to call private methods in public sealed traits. From f258782581735279aa9dbca2b8eee926df074f55 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 13:41:51 -0700 Subject: [PATCH 58/67] add debugging --- src/hive/inner/queue/channel.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/hive/inner/queue/channel.rs b/src/hive/inner/queue/channel.rs index fece56b..c9a3ae2 100644 --- a/src/hive/inner/queue/channel.rs +++ b/src/hive/inner/queue/channel.rs @@ -406,6 +406,7 @@ mod local_batch { break; } } + println!("batch size: {}", batch_size); } Ok(first) } else { From c9bee1196f67e39828849ef80c86ae2bc41dcc35 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 13:58:24 -0700 Subject: [PATCH 59/67] fix test --- src/hive/mod.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 4148b81..abbb5bf 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -2450,7 +2450,9 @@ mod weighted_map_tests { const WEIGHT: u32 = 25; const WEIGHT_LIMIT: u64 = WEIGHT as u64 * NUM_TASKS_PER_THREAD as u64; // schedule 2 * NUM_THREADS tasks, and set the weight limit at 2 * task weight, such that, - // even though the batch size is > 2, each thread should only take 2 tasks + // even though the batch size is > 2, each thread should only take 2 tasks; don't start any + // threads yet, as there can be a delay before the tasks are available that can lead to + // the first call to `try_pop` queuing a smaller than expected batch let hive = builder_factory(false) .with_worker(RetryCaller::from( |i: u8, ctx: &Context| -> Result<(u8, Option), ApplyError> { @@ -2458,10 +2460,11 @@ mod weighted_map_tests { Ok((i, ctx.thread_index())) }, )) - .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT) .weight_limit(WEIGHT_LIMIT) .build(); + thread::sleep(Duration::from_secs(1)); // wait for tasks to be scheduled + assert_eq!(hive.grow(NUM_THREADS).unwrap(), NUM_THREADS); let inputs = (0..NUM_TASKS as u8).map(|i| (i, 
WEIGHT)).into_weighted(); let (mut outputs, thread_indices) = hive .map(inputs) From 92928a58d554b65c77c42bafbf8a0916e1047884 Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 14:07:49 -0700 Subject: [PATCH 60/67] fix test --- Cargo.toml | 2 +- src/hive/mod.rs | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cb06524..d25a111 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ name = "perf" harness = false [features] -default = [] +default = ["local-batch"] affinity = ["dep:core_affinity"] local-batch = [] retry = [] diff --git a/src/hive/mod.rs b/src/hive/mod.rs index abbb5bf..2c2e35b 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -2402,8 +2402,8 @@ mod weighted_map_tests { use crate::bee::stock::{RetryCaller, Thunk, ThunkWorker}; use crate::bee::{ApplyError, Context}; use crate::hive::{ - Builder, Outcome, TaskQueuesBuilder, Weighted, WeightedIteratorExt, channel_builder, - workstealing_builder, + Builder, Outcome, OutcomeIteratorExt, TaskQueuesBuilder, Weighted, WeightedIteratorExt, + channel_builder, workstealing_builder, }; use rstest::*; use std::collections::HashMap; @@ -2463,12 +2463,17 @@ mod weighted_map_tests { .batch_limit(BATCH_LIMIT) .weight_limit(WEIGHT_LIMIT) .build(); - thread::sleep(Duration::from_secs(1)); // wait for tasks to be scheduled - assert_eq!(hive.grow(NUM_THREADS).unwrap(), NUM_THREADS); let inputs = (0..NUM_TASKS as u8).map(|i| (i, WEIGHT)).into_weighted(); - let (mut outputs, thread_indices) = hive - .map(inputs) - .map(Outcome::unwrap) + let (tx, rx) = crate::hive::outcome_channel(); + let task_ids = hive.map_send(inputs, tx); + // wait for tasks to be scheduled + thread::sleep(Duration::from_secs(1)); + assert_eq!(hive.grow(NUM_THREADS).unwrap(), NUM_THREADS); + // wait for all tasks to complete + hive.join(); + let (mut outputs, thread_indices) = rx + .into_iter() + .select_unordered_outputs(task_ids) .unzip::<_, _, Vec<_>, Vec<_>>(); outputs.sort(); assert_eq!(outputs, (0..NUM_TASKS as u8).collect::>()); From cdcec294cd18617073c744aa80c6d784e1c8b01c Mon Sep 17 00:00:00 2001 From: jdidion Date: Fri, 14 Mar 2025 14:41:42 -0700 Subject: [PATCH 61/67] use backoff rather than park --- src/hive/inner/queue/workstealing.rs | 47 +++++++++++++++++++--------- src/hive/mod.rs | 1 - 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/hive/inner/queue/workstealing.rs b/src/hive/inner/queue/workstealing.rs index 08afc96..3ff8647 100644 --- a/src/hive/inner/queue/workstealing.rs +++ b/src/hive/inner/queue/workstealing.rs @@ -5,21 +5,17 @@ //! thread. If the `local-batch` feature is enabled, a worker thread will try to fill its local queue //! up to the limit when stealing from the global queue. use super::{Config, PopTaskError, Status, Task, TaskQueues, Token, WorkerQueues}; -#[cfg(feature = "local-batch")] -use crate::atomic::Atomic; +use crate::atomic::{Atomic, AtomicBool}; use crate::bee::Worker; use crossbeam_deque::{Injector, Stealer}; use crossbeam_queue::SegQueue; +use crossbeam_utils::Backoff; use derive_more::Debug; use nanorand::{Rng, tls as rand}; use parking_lot::RwLock; +use std::any; use std::ops::Deref; use std::sync::Arc; -use std::time::Duration; -use std::{any, thread}; - -/// Time to wait after trying to pop and finding all queues empty. -const EMPTY_DELAY: Duration = Duration::from_millis(100); /// `TaskQueues` implementation using workstealing. 
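+///
+/// When every queue is empty, a worker now snoozes with `crossbeam_utils::Backoff`
+/// rather than parking the thread, and the backoff is reset once a task is obtained
+/// again. A sketch of the idea (not the literal worker loop):
+///
+/// ```ignore
+/// let backoff = Backoff::new();
+/// loop {
+///     match worker_queues.try_pop() {
+///         Ok(task) => { backoff.reset(); /* execute the task */ }
+///         Err(PopTaskError::Empty) => backoff.snooze(),
+///         Err(PopTaskError::Closed) => break,
+///     }
+/// }
+/// ```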
#[derive(Debug)] @@ -115,7 +111,7 @@ impl GlobalQueue { /// /// Returns the task if one is stolen successfully, otherwise snoozes for a bit and then /// returns `PopTaskError::Empty`. Returns `PopTaskError::Closed` if the queue is closed. - fn try_steal_from_worker_or_snooze(&self) -> Result, PopTaskError> { + fn try_steal_from_worker_or_snooze(&self, backoff: &Backoff) -> Result, PopTaskError> { let stealers = self.stealers.read(); let n = stealers.len(); // randomize the stealing order, to prevent always stealing from the same thread @@ -130,7 +126,7 @@ impl GlobalQueue { // TODO: instead try the parking approach used in rust-executors, which seems // more performant under most circumstances // https://github.com/Bathtor/rust-executors/blob/master/executors/src/crossbeam_workstealing_pool.rs#L976 - thread::park_timeout(EMPTY_DELAY); + backoff.snooze(); PopTaskError::Empty } }) @@ -138,11 +134,11 @@ impl GlobalQueue { /// Tries to steal a task from the global queue, otherwise tries to steal a task from another /// worker thread. - fn try_pop_unchecked(&self) -> Result, PopTaskError> { + fn try_pop_unchecked(&self, backoff: &Backoff) -> Result, PopTaskError> { if let Some(task) = self.queue.steal().success() { Ok(task) } else { - self.try_steal_from_worker_or_snooze() + self.try_steal_from_worker_or_snooze(backoff) } } @@ -155,6 +151,7 @@ impl GlobalQueue { local_batch: &crossbeam_deque::Worker>, batch_limit: usize, weight_limit: u64, + backoff: &Backoff, ) -> Result, PopTaskError> { // if we only have a size limit but not a weight limit, use the batch-stealing function // provided by `Injector` @@ -188,7 +185,7 @@ impl GlobalQueue { } return Ok(first); } - self.try_steal_from_worker_or_snooze() + self.try_steal_from_worker_or_snooze(backoff) } fn is_closed(&self) -> bool { @@ -218,6 +215,8 @@ pub struct WorkstealingWorkerQueues { local: crossbeam_deque::Worker>, global: Arc>, shared: Arc>, + backoff: Backoff, + snoozing: AtomicBool, } impl WorkstealingWorkerQueues { @@ -230,6 +229,8 @@ impl WorkstealingWorkerQueues { global: Arc::clone(global), local, shared: Arc::clone(shared), + backoff: Backoff::new(), + snoozing: Default::default(), } } } @@ -240,7 +241,22 @@ impl WorkerQueues for WorkstealingWorkerQueues { } fn try_pop(&self) -> Result, PopTaskError> { - self.shared.try_pop(&self.global, &self.local) + let result = self + .shared + .try_pop(&self.global, &self.local, &self.backoff); + match &result { + Ok(_) | Err(PopTaskError::Closed) if self.snoozing.get() => { + // if the worker has been snoozing and got a task, reset the backoff + self.backoff.reset(); + self.snoozing.set(false); + } + Err(PopTaskError::Empty) => { + // if the queue was empty, the worker must have snoozed + self.snoozing.set(true); + } + _ => (), + }; + result } #[cfg(feature = "retry")] @@ -306,6 +322,7 @@ impl LocalQueueShared { &self, global: &GlobalQueue, local_batch: &crossbeam_deque::Worker>, + backoff: &Backoff, ) -> Result, PopTaskError> { if !global.status.can_pop() { return Err(PopTaskError::Closed); @@ -330,10 +347,10 @@ impl LocalQueueShared { let batch_limit = self.batch_limit.get(); if batch_limit > 0 { let weight_limit = self.weight_limit.get(); - return global.try_refill_and_pop(local_batch, batch_limit, weight_limit); + return global.try_refill_and_pop(local_batch, batch_limit, weight_limit, backoff); } } - global.try_pop_unchecked() + global.try_pop_unchecked(backoff) } fn drain_into(self, tasks: &mut Vec>) { diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 2c2e35b..65bfbfe 100644 --- 
a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -2488,7 +2488,6 @@ mod weighted_map_tests { .or_insert(1); counts }); - println!("{:?}", counts); assert!(counts.values().all(|&count| count == NUM_TASKS_PER_THREAD)); } From 5b71f831c3c0793d48a75875fa4849af4803e969 Mon Sep 17 00:00:00 2001 From: jdidion Date: Mon, 17 Mar 2025 10:16:09 -0700 Subject: [PATCH 62/67] fix doc issues --- .github/workflows/ci.yml | 2 +- README.md | 93 +++++++++++++++++++-------------------- src/bee/mod.rs | 4 +- src/hive/builder/bee.rs | 8 ++-- src/hive/builder/full.rs | 14 +++--- src/hive/builder/mod.rs | 6 +-- src/hive/builder/open.rs | 14 ++++-- src/hive/builder/queue.rs | 25 +++-------- src/hive/hive.rs | 6 +-- src/hive/inner/builder.rs | 18 ++++---- src/hive/mod.rs | 54 ++++++++++++----------- src/hive/outcome/mod.rs | 2 +- src/hive/outcome/queue.rs | 4 ++ src/lib.rs | 56 +++++++++++++++++------ 14 files changed, 167 insertions(+), 139 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9331dab..57c57f8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,7 +23,7 @@ jobs: cargo clippy --all-targets -F affinity,local-batch,retry \ -- -D warnings $(cat .lints | cut -f1 -d"#" | tr '\n' ' ') - run: cargo fmt -- --check - - run: cargo doc -F affinity,local-batch,retry + - run: RUSTDOCFLAGS="-D warnings" cargo doc -F affinity,local-batch,retry - run: cargo test -F affinity,local-batch,retry --doc coverage: diff --git a/README.md b/README.md index 534f21c..8117cb3 100644 --- a/README.md +++ b/README.md @@ -16,39 +16,37 @@ is sometimes called a "worker pool"). ### Overview * Operations are defined by implementing the [`Worker`](https://docs.rs/beekeeper/latest/beekeeper/bee/worker/trait.Worker.html) trait. -* A [`Builder`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html) is used to configure and create a worker pool - called a [`Hive`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html). +* A [`Builder`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/trait.Builder.html) is used to configure and create a worker pool called a [`Hive`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html). +* `Hive` is generic over + * The type of [`Queen`](https://docs.rs/beekeeper/latest/beekeeper/bee/queen/trait.Queen.html) which creates `Worker` instances + * The type of [`TaskQueues`](https://docs.rs/beekeeper/latest/beekeeper/hive/trait.TaskQueues.html), which provides the global and worker thread-local queues for managing tasks +* Currently, two `TaskQueues` implementations are available: + * Channel: uses a [`crossbeam`](https://github.com/crossbeam-rs/crossbeam) channel to send tasks from the `Hive` to worker threads + * When the `local-batch` feature is enabled, local batch queues are implemented using [`crossbeam_queue::ArrayQueue`](https://docs.rs/crossbeam/latest/crossbeam/queue/struct.ArrayQueue.html) + * Workstealing: * The `Hive` creates a `Worker` instance for each thread in the pool. * Each thread in the pool continually: - * Recieves a task from an input [`channel`](https://doc.rust-lang.org/stable/std/sync/mpsc/fn.channel.html), + * Receives a task from an input queue, * Calls its `Worker`'s [`apply`](https://docs.rs/beekeeper/latest/beekeeper/bee/worker/trait.Worker.html#method.apply) method on the input, and * Produces an [`Outcome`](https://docs.rs/beekeeper/latest/beekeeper/hive/outcome/outcome/enum.Outcome.html). 
* Depending on which of `Hive`'s methods are called to submit a task (or batch of tasks), the `Outcome`(s) may be returned as an `Iterator`, sent to an output `channel`, or stored in the `Hive` for later retrieval.
* A `Hive` may create `Worker`s in one of three ways:
- * Call the `default()` function on a `Worker` type that implements
- [`Default`](std::default::Default)
- * Clone an instance of a `Worker` that implements
- [`Clone`](std::clone::Clone)
- * Call the [`create()`](https://docs.rs/beekeeper/latest/beekeeper/bee/queen/trait.Queen.html#method.create) method on a worker factory that
- implements the [`Queen`](https://docs.rs/beekeeper/latest/beekeeper/bee/queen/trait.Queen.html) trait.
-* Both `Worker`s and `Queen`s may be stateful, i.e., `Worker::apply()` and `Queen::create()`
- both take `&mut self`.
+ * Call the `default()` function on a `Worker` type that implements [`Default`](std::default::Default)
+ * Clone an instance of a `Worker` that implements [`Clone`](std::clone::Clone)
+ * Call the [`create()`](https://docs.rs/beekeeper/latest/beekeeper/bee/queen/trait.Queen.html#method.create) method on a worker factory that implements the [`Queen`](https://docs.rs/beekeeper/latest/beekeeper/bee/queen/trait.Queen.html) trait.
+* A `Worker` may be stateful, i.e., `Worker::apply()` takes a `&mut self`
+* While `Queen` is not stateful, [`QueenMut`](https://docs.rs/beekeeper/latest/beekeeper/bee/queen/trait.QueenMut.html) may be (i.e., its `create()` method takes a `&mut self`)
* Although it is strongly recommended to avoid `panic`s in worker threads (and thus, within `Worker` implementations), the `Hive` does automatically restart any threads that panic.
-* A `Hive` may be [`suspend`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.suspend)ed and
- [`resume`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.resume)d at any time. When a `Hive` is suspended, worker threads
- do no work and tasks accumulate in the input `channel`.
-* Several utility functions are provided in the [util](https://docs.rs/beekeeper/latest/beekeeper/util/) module. Notably, the `map`
- and `try_map` functions enable simple parallel processing of a single batch of tasks.
-* Several useful `Worker` implementations are provided in the [stock](https://docs.rs/beekeeper/latest/beekeeper/bee/stock/) module.
- Most notable are those in the [`call`](https://docs.rs/beekeeper/latest/beekeeper/bee/stock/call/) submodule, which provide
- different ways of wrapping `callable`s, i.e., closures and function pointers.
+* A `Hive` may be [`suspend`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.suspend)ed and [`resume`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.resume)d at any time. When a `Hive` is suspended, worker threads do no work and tasks accumulate in the input queue.
+* Several utility functions are provided in the [util](https://docs.rs/beekeeper/latest/beekeeper/util/) module. Notably, the `map` and `try_map` functions enable simple parallel processing of a single batch of tasks.
+* Several useful `Worker` implementations are provided in the [stock](https://docs.rs/beekeeper/latest/beekeeper/bee/stock/) module. Most notable are those in the [`call`](https://docs.rs/beekeeper/latest/beekeeper/bee/stock/call/) submodule, which provide different ways of wrapping `callable`s, i.e., closures and function pointers.
* The following optional features are provided via feature flags:
- * `affinity`: worker threads may be pinned to CPU cores to minimize the overhead of
- context-switching.
- * `local-batch` (>=0.3.0): worker threads take batches of tasks from the input channel and queue them locally, which may alleviate thread contention, especially when there are many short-lived tasks.
+ * `affinity`: worker threads may be pinned to CPU cores to minimize the overhead of context-switching.
+ * `local-batch` (>=0.3.0): worker threads take batches of tasks from the global input queue and add them to a local queue, which may alleviate thread contention, especially when there are many short-lived tasks.
+ * Tasks may be [`Weighted`](https://docs.rs/beekeeper/latest/beekeeper/hive/weighted/struct.Weighted.html) to enable balancing unevenly sized tasks between worker threads.
* `retry`: Tasks that fail due to transient errors (e.g., temporarily unavailable resources) may be retried a set number of times, with an optional, exponentially increasing delay between retries.
@@ -65,31 +63,27 @@ To parallelize a task, you'll need two things:
* Implement your own (See Example 3 below)
* `use` the necessary traits (e.g., `use beekeeper::bee::prelude::*`)
* Define a `struct` for your worker
- * Implement the `Worker` trait on your struct and define the `apply` method with the
- logic of your task
+ * Implement the `Worker` trait on your struct and define the `apply` method with the logic of your task
* Do at least one of the following:
* Implement `Default` for your worker
* Implement `Clone` for your worker
- * Create a custom worker fatory that implements the `Queen` trait
+ * Create a custom worker factory that implements the `Queen` or `QueenMut` trait
2. A `Hive` to execute your tasks. Your options are:
* Use one of the convenience methods in the [util](https://docs.rs/beekeeper/latest/beekeeper/util/) module (see Example 1 below)
- * Create a `Hive` manually using [`Builder`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html) (see Examples 2
- and 3 below)
- * [`Builder::new()`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.new) creates an empty `Builder`
- * [`Builder::default()`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.default) creates a `Builder`
+ * Create a `Hive` manually using a [`Builder`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/trait.Builder.html) (see Examples 2 and 3 below)
+ * [`OpenBuilder`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html) is the most general builder
+ * [`OpenBuilder::new()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.new) creates an empty `OpenBuilder`
+ * [`Builder::default()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.default) creates an `OpenBuilder` with the global default settings (which may be changed using the functions in the [`hive`](https://docs.rs/beekeeper/latest/beekeeper/hive/) module, e.g., `beekeeper::hive::set_num_threads_default(4)`).
- * Use one of the `build_*` methods to build the `Hive`: - * If you have a `Worker` that implements `Default`, use - [`build_with_default::()`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.build_with_default) - * If you have a `Worker` that implements `Clone`, use - [`build_with(MyWorker::new())`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.build_with) - * If you have a custom `Queen`, use - [`build_default::()`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.build_default) if it implements - `Default`, otherwise use [`build(MyQueen::new())`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.build) - * Note that [`Builder::num_threads()`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.num_threads) must be set - to a non-zero value, otherwise the built `Hive` will not start any worker threads - until you call the [`Hive::grow()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.grow) method. + * The builder must be specialized for the `Queen` and `TaskQueues` types: + * If you have a `Worker` that implements `Default`, use [`with_worker_default::()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_worker_default) + * If you have a `Worker` that implements `Clone`, use [`with_worker(MyWorker::new())`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_worker) + * If you have a custom `Queen`, use [`with_queen_default::()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_queen_default) if it implements `Default`, otherwise use [`with_queen(MyQueen::new())`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_queen) + * If you have a custom `QueenMut`, use [`with_queen_mut_default::()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_queen_mut_default) if it implements `Default`, otherwise use [`with_queen_mut(MyQueenMut::new())`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_queen_mut) + * Use the [`with_channel_queues`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_channel_queues.html) or [`with_workstealing_queues`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.OpenBuilder.html#method.with_workstealing_queues.html) to configure the `TaskQueues` implementation + * Use the `build()` methods to build the `Hive` + * Note that [`Builder::num_threads()`](https://docs.rs/beekeeper/latest/beekeeper/hive/builder/struct.Builder.html#method.num_threads) must be set to a non-zero value, otherwise the built `Hive` will not start any worker threads until you call the [`Hive::grow()`](https://docs.rs/beekeeper/latest/beekeeper/hive/struct.Hive.html#method.grow) method. Once you've created a `Hive`, use its methods to submit tasks for processing. 
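+
+For example, the `apply` method (described below) submits a single task and blocks
+until its `Outcome` is available. A sketch, where `MyWorker` stands in for any
+`Worker` type implementing `Default` and `my_input` for a value of its input type:
+
+```rust
+use beekeeper::hive::{Builder, TaskQueuesBuilder, channel_builder};
+
+let hive = channel_builder(true)
+    .num_threads(4)
+    .with_worker_default::<MyWorker>()
+    .build();
+let outcome = hive.apply(my_input);
+```
+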
There are four groups of methods available: @@ -98,7 +92,7 @@ four groups of methods available: that implements `IntoIterator`) * `map`: submits an arbitrary batch of tasks (i.e., anything that implements `IntoIterator`) * `scan`: Similar to `map`, but you also provide 1) an initial value for a state variable, and - 2) a function that transforms each item in the input iterator into the input type required by + 1) a function that transforms each item in the input iterator into the input type required by the `Worker`, and also has access to (and may modify) the state variable. There are multiple methods in each group that differ by how the task results (called @@ -170,14 +164,14 @@ let hive = Builder::new() // return results to your own channel... let (tx, rx) = outcome_channel(); let _ = hive.swarm_send( - (0..10).map(|i: i32| Thunk::of(move || i * i)), + (0..10).map(|i: i32| Thunk::from(move || i * i)), tx ); assert_eq!(285, rx.into_outputs().take(10).sum()); // return results as an iterator... let total = hive - .swarm_unordered((0..10).map(|i: i32| Thunk::of(move || i * -i))) + .swarm_unordered((0..10).map(|i: i32| Thunk::from(move || i * -i))) .into_outputs() .sum(); assert_eq!(-285, total); @@ -230,7 +224,7 @@ impl Worker for CatWorker { fn apply( &mut self, input: Self::Input, - _: &Context + _: &Context ) -> WorkerResult { self.write_char(input).map_err(|error| { ApplyError::Fatal { input: Some(input), error } @@ -252,7 +246,7 @@ impl CatQueen { } } -impl Queen for CatQueen { +impl QueenMut for CatQueen { type Kind = CatWorker; fn create(&mut self) -> Self::Kind { @@ -288,7 +282,7 @@ impl Drop for CatQueen { // build the Hive let hive = Builder::new() .num_threads(4) - .build_default::() + .build_default_mut::() .unwrap(); // prepare inputs @@ -311,8 +305,11 @@ assert_eq!(output, b"abcdefgh"); // shutdown the hive, use the Queen to wait on child processes, and // report errors let (mut queen, _outcomes) = hive.try_into_husk().unwrap().into_parts(); -let (wait_ok, wait_err): (Vec<_>, Vec<_>) = - queen.wait_for_all().into_iter().partition(Result::is_ok); +let (wait_ok, wait_err): (Vec<_>, Vec<_>) = queen + .into_inner() + .wait_for_all() + .into_iter() + .partition(Result::is_ok); if !wait_err.is_empty() { panic!( "Error(s) occurred while waiting for child processes: {:?}", diff --git a/src/bee/mod.rs b/src/bee/mod.rs index c42bc21..b1be414 100644 --- a/src/bee/mod.rs +++ b/src/bee/mod.rs @@ -87,9 +87,9 @@ //! //! It is often not necessary to manually implement the `Queen` trait. For exmaple, if your `Worker` //! implements `Default`, then you can use [`DefaultQueen`] implicitly by calling -//! [`Builder::build_with_default`](crate::hive::Builder::build_with_default). Similarly, +//! [`OpenBuilder::with_worker_default`](crate::hive::OpenBuilder::with_worker_default). Similarly, //! if your `Worker` implements `Clone`, then you can use [`CloneQueen`] -//! implicitly by calling [`Builder::build_with`](crate::hive::Builder::build_with). +//! implicitly by calling [`OpenBuilder::with_worker`](crate::hive::OpenBuilder::with_worker). //! //! A `Queen` should never panic when creating `Worker`s. //! diff --git a/src/hive/builder/bee.rs b/src/hive/builder/bee.rs index d666ed1..a393b84 100644 --- a/src/hive/builder/bee.rs +++ b/src/hive/builder/bee.rs @@ -14,7 +14,7 @@ pub struct BeeBuilder { impl BeeBuilder { /// Creates a new `BeeBuilder` with the given queen and no options configured. 
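+    ///
+    /// For example (a sketch; `MyQueen` stands in for any `Queen` implementation):
+    ///
+    /// ```ignore
+    /// let builder = BeeBuilder::empty(MyQueen::default());
+    /// ```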
- pub fn empty>(queen: Q) -> Self { + pub fn empty(queen: Q) -> Self { Self { config: Config::empty(), queen, @@ -23,7 +23,7 @@ impl BeeBuilder { /// Creates a new `BeeBuilder` with the given `queen` and options configured with global /// preset values. - pub fn preset>(queen: Q) -> Self { + pub fn preset(queen: Q) -> Self { Self { config: Config::default(), queen, @@ -182,8 +182,8 @@ mod tests { #[rstest] fn test_queen( #[values( - BeeBuilder::::empty::, - BeeBuilder::::preset:: + BeeBuilder::::empty, + BeeBuilder::::preset )] factory: F, #[values( diff --git a/src/hive/builder/full.rs b/src/hive/builder/full.rs index 2662bcd..4865420 100644 --- a/src/hive/builder/full.rs +++ b/src/hive/builder/full.rs @@ -17,7 +17,7 @@ pub struct FullBuilder> { impl> FullBuilder { /// Creates a new `FullBuilder` with the given queen and no options configured. - pub fn empty>(queen: Q) -> Self { + pub fn empty(queen: Q) -> Self { Self { config: Config::empty(), queen, @@ -27,10 +27,10 @@ impl> FullBuilder { /// Creates a new `FullBuilder` with the given `queen` and options configured with global /// defaults. - pub fn preset>(queen: I) -> Self { + pub fn preset(queen: Q) -> Self { Self { config: Config::default(), - queen: queen.into(), + queen, _queues: PhantomData, } } @@ -91,8 +91,8 @@ mod tests { #[rstest] fn test_channel( #[values( - FullBuilder::>>::empty::, - FullBuilder::>>::preset:: + FullBuilder::>>::empty, + FullBuilder::>>::preset )] factory: F, ) where @@ -105,8 +105,8 @@ mod tests { #[rstest] fn test_workstealing( #[values( - FullBuilder::>>::empty::, - FullBuilder::>>::preset:: + FullBuilder::>>::empty, + FullBuilder::>>::preset )] factory: F, ) where diff --git a/src/hive/builder/mod.rs b/src/hive/builder/mod.rs index e8b2269..e2dcb3e 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -31,8 +31,8 @@ //! [`ApplyError::Retryable`](crate::bee::ApplyError#Retryable) before giving up. //! * [`Builder::retry_factor`]: [`Duration`](std::time::Duration) factor for exponential backoff //! when retrying an `ApplyError::Retryable` error. -//! * [`Builder::with_default_retries`] sets the retry options to the global defaults, while -//! [`Builder::with_no_retries`] disabled retrying. +//! * [`Builder::with_default_max_retries`] and [`Builder::with_default_retry_factor`] set the +//! retry options to the global defaults, while [`Builder::with_no_retries`] disabled retrying. //! //! The following configuration options are available when the `affinity` feature is enabled: //! * [`Builder::core_affinity`]: List of CPU core indices to which the threads should be pinned. @@ -51,7 +51,7 @@ pub use queue::TaskQueuesBuilder; pub use queue::channel::ChannelBuilder; pub use queue::workstealing::WorkstealingBuilder; -use crate::hive::inner::{BuilderConfig, Token}; +use crate::hive::inner::{Builder, BuilderConfig, Token}; /// Creates a new `OpenBuilder`. If `with_defaults` is `true`, the builder will be pre-configured /// with the global defaults. 
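For orientation, the reworked builder flow above composes end-to-end roughly as follows (a sketch, assuming the `EchoWorker` stock bee and the `channel_builder` convenience function that appear in the examples and tests later in this series; exact paths and signatures may differ):

```rust
use beekeeper::bee::stock::EchoWorker;
use beekeeper::hive::prelude::*;

fn main() {
    // channel_builder(false) starts from an empty (unconfigured) builder;
    // passing `true` would pre-populate it with the global defaults.
    let hive = channel_builder(false)
        // must be non-zero, or no worker threads start until `grow()` is called
        .num_threads(4)
        // specialize the Queen: implicitly DefaultQueen<EchoWorker<usize>>
        .with_worker_default::<EchoWorker<usize>>()
        // only the fully-specialized builder has `build()`
        .build();

    // EchoWorker returns its input unchanged
    let outcome = hive.apply(42usize);
    assert_eq!(42, outcome.unwrap());
}
```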
diff --git a/src/hive/builder/open.rs index ea81a06..5600870 100644 --- a/src/hive/builder/open.rs +++ b/src/hive/builder/open.rs @@ -23,8 +23,8 @@ use crate::hive::Config; /// /// # Examples /// -/// Build a [`Hive`] that uses a maximum of eight threads simultaneously and each thread has -/// a 8 MB stack size: +/// Build a [`Hive`](crate::hive::Hive) that uses a maximum of eight threads simultaneously and +/// each thread has an 8 MB stack size: /// /// ``` /// # use beekeeper::hive::{Builder, OpenBuilder}; /// @@ -116,8 +116,8 @@ impl OpenBuilder { /// assert_eq!(husk.queen().get().num_workers, 8); /// # } /// ``` - pub fn with_queen>(self, queen: I) -> BeeBuilder { - BeeBuilder::from_config_and_queen(self.0, queen.into()) + pub fn with_queen(self, queen: Q) -> BeeBuilder { + BeeBuilder::from_config_and_queen(self.0, queen) } /// Consumes this `Builder` and returns a new [`BeeBuilder`] using a [`Queen`] created with /// [`Q::default()`](std::default::Default) to create [`Worker`]s. pub fn with_queen_default(self) -> BeeBuilder { BeeBuilder::from_config_and_queen(self.0, Q::default()) } + /// Consumes this `Builder` and returns a new [`BeeBuilder`] using a [`QueenCell`] wrapping + /// the given [`QueenMut`] to create [`Worker`]s. + pub fn with_queen_mut(self, queen: Q) -> BeeBuilder> { + BeeBuilder::from_config_and_queen(self.0, QueenCell::new(queen)) + } + /// Consumes this `Builder` and returns a new [`BeeBuilder`] using a [`QueenMut`] created with /// [`Q::default()`](std::default::Default) to create [`Worker`]s. pub fn with_queen_mut_default(self) -> BeeBuilder> { diff --git a/src/hive/builder/queue.rs index 02b0ece..b3c8f0a 100644 --- a/src/hive/builder/queue.rs +++ b/src/hive/builder/queue.rs @@ -1,6 +1,6 @@ -use super::FullBuilder; +use super::{Builder, FullBuilder}; use crate::bee::{CloneQueen, DefaultQueen, Queen, QueenCell, QueenMut, Worker}; -use crate::hive::{Builder, TaskQueues}; +use crate::hive::TaskQueues; /// Trait implemented by builders specialized to a `TaskQueues` type. pub trait TaskQueuesBuilder: Builder + Clone + Default + Sized { @@ -12,10 +12,7 @@ pub trait TaskQueuesBuilder: Builder + Clone + Default + Sized { /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to /// create [`Worker`]s. - fn with_queen>( - self, - queen: I, - ) -> FullBuilder>; + fn with_queen(self, queen: Q) -> FullBuilder>; /// Consumes this `Builder` and returns a new [`FullBuilder`] using a [`Queen`] created with /// [`Q::default()`](std::default::Default) to create [`Worker`]s. @@ -78,12 +75,8 @@ pub mod channel { /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to /// create [`Worker`]s. - fn with_queen(self, queen: I) -> FullBuilder> - where - Q: Queen, - I: Into, - { - FullBuilder::from_config_and_queen(self.0, queen.into()) + fn with_queen(self, queen: Q) -> FullBuilder> { + FullBuilder::from_config_and_queen(self.0, queen) } } @@ -118,12 +111,8 @@ pub mod workstealing { /// Consumes this `Builder` and returns a new [`FullBuilder`] using the given [`Queen`] to /// create [`Worker`]s.
- fn with_queen(self, queen: I) -> FullBuilder> - where - Q: Queen, - I: Into, - { - FullBuilder::from_config_and_queen(self.0, queen.into()) + fn with_queen(self, queen: Q) -> FullBuilder> { + FullBuilder::from_config_and_queen(self.0, queen) } } diff --git a/src/hive/hive.rs b/src/hive/hive.rs index 182a702..7bc886f 100644 --- a/src/hive/hive.rs +++ b/src/hive/hive.rs @@ -447,7 +447,7 @@ impl, T: TaskQueues> Hive { /// Returns the number of worker threads that have been requested, i.e., the maximum number of /// tasks that could be processed concurrently. This may be greater than - /// [`active_workers`](Self::active_workers) if any of the worker threads failed to start. + /// [`alive_workers`](Self::alive_workers) if any of the worker threads failed to start. pub fn max_workers(&self) -> usize { self.shared().num_threads() } @@ -491,8 +491,8 @@ impl, T: TaskQueues> Hive { /// corrupted such that it is no longer able to process tasks. /// /// Note that, when a `Hive` is poisoned, it is still possible to call methods that extract - /// its stored [`Outcome`]s (e.g., [`take_stored`](Self::take_stored)) or consume it (e.g., - /// [`try_into_husk`](Self::try_into_husk)). + /// its stored [`Outcome`]s (e.g., [`remove_all`](crate::hive::OutcomeStore::remove_all)) or + /// consume it (e.g., [`try_into_husk`](Self::try_into_husk)). pub fn is_poisoned(&self) -> bool { self.shared().is_poisoned() } diff --git a/src/hive/inner/builder.rs b/src/hive/inner/builder.rs index 7bfb00c..93d84a9 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -11,8 +11,8 @@ pub trait BuilderConfig { /// This is a sealed trait, meaning it cannot be implemented outside of this crate. pub trait Builder: BuilderConfig + Sized { /// Sets the maximum number of worker threads that will be alive at any given moment in the - /// built [`Hive`]. If not specified, the built `Hive` will not be initialized with worker - /// threads until [`Hive::grow`] is called. + /// built [`Hive`](crate::hive::Hive). If not specified, the built `Hive` will not be + /// initialized with worker threads until [`Hive::grow`](crate::hive::Hive::grow) is called. /// /// # Examples /// @@ -49,7 +49,8 @@ pub trait Builder: BuilderConfig + Sized { self } - /// Specifies that the built [`Hive`] will use all available CPU cores for worker threads. + /// Specifies that the built [`Hive`](crate::hive::Hive) will use all available CPU cores for + /// worker threads. /// /// # Examples /// @@ -80,8 +81,9 @@ pub trait Builder: BuilderConfig + Sized { self } - /// Sets the thread name for each of the threads spawned by the built [`Hive`]. If not - /// specified, threads spawned by the thread pool will be unnamed. + /// Sets the thread name for each of the threads spawned by the built + /// [`Hive`](crate::hive::Hive). If not specified, threads spawned by the thread pool will be + /// unnamed. /// /// # Examples /// @@ -111,9 +113,9 @@ pub trait Builder: BuilderConfig + Sized { self } - /// Sets the stack size (in bytes) for each of the threads spawned by the built [`Hive`]. - /// If not specified, threads spawned by the hive will have a stack size [as specified in - /// the `std::thread` documentation][thread]. + /// Sets the stack size (in bytes) for each of the threads spawned by the built + /// [`Hive`](crate::hive::Hive). If not specified, threads spawned by the hive will have a + /// stack size [as specified in the `std::thread` documentation][thread]. 
/// /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size /// diff --git a/src/hive/mod.rs b/src/hive/mod.rs index 65bfbfe..77cb443 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -2,8 +2,8 @@ //! //! A [`Hive`](crate::hive::Hive) has a pool of worker threads that it uses to execute tasks. //! -//! The `Hive` has a [`Queen`] of type `Q`, which it uses to create a [`Worker`] of type `W` for -//! each thread it starts in the pool. +//! The `Hive` has a [`Queen`](crate::bee::Queen) of type `Q`, which it uses to create a +//! [`Worker`] of type `W` for each thread it starts in the pool. //! //! Each task is submitted to the `Hive` as an input of type `W::Input`, and, optionally, a //! channel where the [`Outcome`] of processing the task will be sent upon completion. To these, @@ -21,12 +21,12 @@ //! # Creating a `Hive` //! //! The typical way to create a `Hive` is using a [`Builder`]. Use -//! [`Builder::new()`](crate::hive::builder::Builder::new) to create an empty (completely -//! unconfigured) `Builder`, or [`Builder::default()`](crate::hive::builder::Builder::default) to +//! [`OpenBuilder::empty()`](crate::hive::OpenBuilder::empty) to create an empty (completely +//! unconfigured) `Builder`, or [`OpenBuilder::default()`](crate::hive::OpenBuilder::default) to //! create a `Builder` configured with the global default values (see below). //! -//! See the [`Builder`] documentation for more details on the options that may be configured, and -//! the `build*` methods available to create the `Hive`. +//! See the [`builder` module documentation](crate::hive::builder) for more details on the options +//! that may be configured, and the `build*` methods available to create the `Hive`. //! //! Building a `Hive` consumes the `Builder`. To create multiple identical `Hive`s, you can `clone` //! the `Builder`. @@ -71,12 +71,11 @@ //! terms of its index, which is a value in the range `0..n`, where `n` is the number of available //! CPU cores. Internally, a mapping is maintained between the index and the OS-specific core ID. //! -//! The [`Builder::core_affinity`](crate::hive::builder::Builder::core_affinity) method accepts a -//! range of core indices that are reserved as *available* for the `Hive` to use for thread-pinning, -//! but they may or may not actually be used (depending on the number of worker threads and core -//! availability). The number of available cores can be smaller or larger than the number of -//! threads. Any thread that is spawned for which there is no corresponding core index is simply -//! started with no core affinity. +//! The [`Builder::core_affinity`] method accepts a range of core indices that are reserved as +//! *available* for the `Hive` to use for thread-pinning, but they may or may not actually be used +//! (depending on the number of worker threads and core availability). The number of available +//! cores can be smaller or larger than the number of threads. Any thread that is spawned for which +//! there is no corresponding core index is simply started with no core affinity. //! //! ``` //! use beekeeper::hive::prelude::*; @@ -115,7 +114,7 @@ //! to return [`ApplyError::Retryable`](crate::bee::ApplyError::Retryable) for transient failures. //! //! When a `Retryable` error occurs, the following steps happen: -//! * The `attempt` number in the task's [`Context`] is incremented. +//! * The `attempt` number in the task's [`Context`](crate::bee::Context) is incremented. //! 
* If the `attempt` number exceeds `max_retries`, the error is converted to //! `Outcome::MaxRetriesAttempted` and sent/stored. //! * Otherwise, the task is added to the `Hive`'s retry queue. @@ -140,8 +139,9 @@ //! //! With the `local-batch` feature enabled, `Builder` gains the //! [`batch_limit`](crate::hive::Builder::batch_limit) method for configuring the size of worker threads' -//! local queues, and `Hive` gains the [`set_worker_batch_limit`](crate::hive::Hive::set_batch_limit) -//! method for changing the batch size of an existing `Hive`. +//! local queues, and `Hive` gains the +//! [`set_worker_batch_limit`](crate::hive::Hive::set_worker_batch_limit) method for changing the +//! batch size of an existing `Hive`. //! //! ## Global defaults //! @@ -154,9 +154,10 @@ //! * `num_threads` //! * [`set_num_threads_default`]: sets the default to a specific value //! * [`set_num_threads_default_all`]: sets the default to all available CPU cores -//! * [`batch_limit`](crate::hive::set_BATCH_LIMIT_default) (requires `feature = "local-batch"`) -//! * [`max_retries`](crate::hive::set_max_retries_default] (requires `feature = "retry"`) -//! * [`retry_factor`](crate::hive::set_retry_factor_default] (requires `feature = "retry"`) +//! * [`batch_limit`](crate::hive::set_batch_limit_default) (requires `feature = "local-batch"`) +//! * [`weight_limit`](crate::hive::set_weight_limit_default) (requires `feature = "local-batch"`) +//! * [`max_retries`](crate::hive::set_max_retries_default) (requires `feature = "retry"`) +//! * [`retry_factor`](crate::hive::set_retry_factor_default) (requires `feature = "retry"`) //! //! The global defaults can be reset to their original values using the [`reset_defaults`] function. //! @@ -164,9 +165,10 @@ //! //! A `Hive` is simply a wrapper around a data structure that is shared between the `Hive`, its //! worker threads, and any clones that have been made of the `Hive`. In other words, cloning a -//! `Hive` simply creates another reference to the same shared data (similar to cloning an [`Arc`]). -//! The worker threads and the shared data structure are dropped automatically when the last `Hive` -//! referring to them is dropped (see "Disposing of a Hive" below). +//! `Hive` simply creates another reference to the same shared data (similar to cloning an +//! [`Arc`](std::sync::Arc)). The worker threads and the shared data structure are dropped +//! automatically when the last `Hive` referring to them is dropped (see "Disposing of a Hive" +//! below). //! //! # Submitting tasks //! @@ -339,8 +341,8 @@ //! outcomes. //! //! Processing can be resumed by calling the [`resume`](crate::hive::Hive::resume) method. -//! Alternatively, the [`resume_send`](crate::hive::Hive::resume_send) or -//! [`resume_store`](crate::hive::Hive::resume_store) method can be used to both resume and +//! The [`swarm_unprocessed_send`](crate::hive::Hive::swarm_unprocessed_send) or +//! [`swarm_unprocessed_store`](crate::hive::Hive::swarm_unprocessed_store) methods can be used to //! submit any unprocessed tasks stored in the `Hive` for (re)processing. //! //! ## Hive poisoning //! @@ -380,7 +382,7 @@ //! The `Husk` can be used to create a new `Builder` //! ([`Husk::as_builder`](crate::hive::husk::Husk::as_builder)) or a new `Hive` //! ([`Husk::into_hive`](crate::hive::husk::Husk::into_hive)).
-mod builder; +pub mod builder; mod context; #[cfg(feature = "affinity")] pub mod cores; @@ -404,14 +406,14 @@ pub use self::cores::{Core, Cores}; pub use self::hive::{DefaultHive, Hive, Poisoned}; pub use self::husk::Husk; pub use self::inner::{ - Builder, ChannelTaskQueues, TaskInput, WorkstealingTaskQueues, set_config::*, + Builder, ChannelTaskQueues, TaskInput, TaskQueues, WorkstealingTaskQueues, set_config::*, }; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; #[cfg(feature = "local-batch")] pub use self::weighted::{Weighted, WeightedExactSizeIteratorExt, WeightedIteratorExt}; use self::context::HiveLocalContext; -use self::inner::{Config, Shared, Task, TaskQueues, WorkerQueues}; +use self::inner::{Config, Shared, Task, WorkerQueues}; use self::outcome::{DerefOutcomes, OutcomeQueue, OwnedOutcomes}; use self::sentinel::Sentinel; use crate::bee::Worker; diff --git a/src/hive/outcome/mod.rs b/src/hive/outcome/mod.rs index b9f9f9f..83f8907 100644 --- a/src/hive/outcome/mod.rs +++ b/src/hive/outcome/mod.rs @@ -18,7 +18,7 @@ use derive_more::Debug; /// The possible outcomes of a task execution. /// /// Each outcome includes the task ID of the task that produced it. Tasks that submitted -/// subtasks (via [`crate::bee::Context::submit_task`]) produce `Outcome` variants that have +/// subtasks (via [`crate::bee::Context::submit`]) produce `Outcome` variants that have /// `subtask_ids`. /// /// Note that `Outcome`s can only be compared or ordered with other `Outcome`s produced by the same diff --git a/src/hive/outcome/queue.rs b/src/hive/outcome/queue.rs index e2b11d7..9ed7b70 100644 --- a/src/hive/outcome/queue.rs +++ b/src/hive/outcome/queue.rs @@ -8,6 +8,10 @@ use std::ops::{Deref, DerefMut}; /// Data structure that supports queuing `Outcomes` from multiple threads (without locking) and /// fetching from a single thread (which requires draining the queue into a map that is behind a /// mutex). +/// +/// TODO: test vs using a +/// [`SkipMap`](https://docs.rs/crossbeam-skiplist/latest/crossbeam_skiplist/struct.SkipMap.html) or +/// [`DashMap`](https://docs.rs/dashmap/latest/dashmap/struct.DashMap.html) pub struct OutcomeQueue { queue: SegQueue>, outcomes: Mutex>>, diff --git a/src/lib.rs b/src/lib.rs index 0238b9b..5b12532 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,9 +8,24 @@ //! * Operations are defined by implementing the [`Worker`](crate::bee::Worker) trait. //! * A [`Builder`](crate::hive::Builder) is used to configure and create a worker pool //! called a [`Hive`](crate::hive::Hive). +//! * `Hive` is generic over: +//! * The type of [`Queen`](crate::bee::Queen) which creates `Worker` instances +//! * The type of [`TaskQueues`](crate::hive::TaskQueues), which provides the global and +//! worker thread-local queues for managing tasks; there are currently two implementations: +//! * Channel: A +//! [`crossbeam` channel](https://docs.rs/crossbeam-channel/latest/crossbeam_channel/) +//! is used to send tasks from the `Hive` to the worker threads. *This is a good choice +//! for most workloads*. +//! * Workstealing: A +//! [`crossbeam_deque::Injector`](https://docs.rs/crossbeam-deque/latest/crossbeam_deque/struct.Injector.html) +//! is used to submit tasks and serves as a global queue. Worker threads each have their +//! own local queue and can take tasks either from the global queue or steal from other +//! workers' local queues if their own queue is empty. This is a good choice for workloads +//!
that are either highly variable from task to task (in terms of processing time), or +//! are fork-join in nature (i.e., tasks that submit sub-tasks). //! * The `Hive` creates a `Worker` instance for each thread in the pool. //! * Each thread in the pool continually: -//! * Recieves a task from an input [`channel`](::std::sync::mpsc::channel), +//! * Receives a task from an input queue, //! * Calls its `Worker`'s [`apply`](crate::bee::Worker::apply) method on the input, and //! * Produces an [`Outcome`](crate::hive::Outcome). //! * Depending on which of `Hive`'s methods are called to submit a task (or batch of tasks), the @@ -21,8 +36,9 @@ //! * Clone an instance of a `Worker` that implements [`Clone`] //! * Call the [`create()`](crate::bee::Queen::create) method on a worker factory that //! implements the [`Queen`](crate::bee::Queen) trait. -//! * Both `Worker`s and `Queen`s may be stateful, i.e., `Worker::apply()` and `Queen::create()` -//! both take `&mut self`. +//! * A `Worker` may be stateful, i.e., `Worker::apply()` takes `&mut self`. +//! * While `Queen` is not stateful, [`QueenMut`](crate::bee::QueenMut) may be (i.e., its +//! `create()` method takes `&mut self`) //! * Although it is strongly recommended to avoid `panic`s in worker threads (and thus, within //! `Worker` implementations), the `Hive` does automatically restart any threads that panic. //! * A `Hive` may be [`suspend`](crate::hive::Hive::suspend)ed and @@ -36,6 +52,11 @@ //! * The following optional features are provided via feature flags: //! * `affinity`: worker threads may be pinned to CPU cores to minimize the overhead of //! context-switching. +//! * `local-batch` (>=0.3.0): worker threads take batches of tasks from the global input queue +//! and add them to a local queue, which may alleviate thread contention, especially when +//! there are many short-lived tasks. +//! * Tasks may be [`Weighted`](crate::hive::Weighted) to enable balancing unevenly sized +//! tasks between worker threads. +//! * `retry`: Tasks that fail due to transient errors (e.g., temporarily unavailable resources) //! may be retried a set number of times, with an optional, exponentially increasing delay //! between retries. //! * Several alternative `channel` implementations are supported: //! * [`crossbeam`](https://docs.rs/crossbeam/latest/crossbeam/) //! * [`flume`](https://github.com/zesterer/flume) @@ -58,23 +79,31 @@ //! * Implement [`Default`] for your worker //! * Implement [`Clone`] for your worker //! * Create a custom worker fatory that implements the [`Queen`](crate::bee::Queen) -//! trait +//! or [`QueenMut`](crate::bee::QueenMut) trait //! 2. A [`Hive`](crate::hive::Hive) to execute your tasks. Your options are: //! * Use one of the convenience methods in the [`util`] module (see Example 1 below) -//! * Create a `Hive` manually using [`Builder`](crate::hive::Builder) (see Examples 2 +//! * Create a `Hive` manually using a [`Builder`](crate::hive::Builder) (see Examples 2 //! and 3 below) -//! * [`Builder::new()`](crate::hive::Builder::new) creates an empty `Builder` -//! * [`Builder::default()`](crate::hive::Builder::default) creates a `Builder` +//! * [`OpenBuilder`](crate::hive::OpenBuilder) is the most general builder +//! * [`OpenBuilder::empty()`](crate::hive::OpenBuilder::empty) creates an empty `OpenBuilder` +//! * [`OpenBuilder::default()`](crate::hive::OpenBuilder::default) creates an `OpenBuilder` //! with the global default settings (which may be changed using the functions in the //! [`hive`] module, e.g., `beekeeper::hive::set_num_threads_default(4)`). -//! * Use one of the `build_*` methods to build the `Hive`: +//! * The builder must be specialized for the `Queen` and `TaskQueues` types: //!
* If you have a `Worker` that implements `Default`, use -//! [`build_with_default::<MyWorker>()`](crate::hive::Builder::build_with_default) +//! [`with_worker_default::<MyWorker>()`](crate::hive::OpenBuilder::with_worker_default) //! * If you have a `Worker` that implements `Clone`, use -//! [`build_with(MyWorker::new())`](crate::hive::Builder::build_with) +//! [`with_worker(MyWorker::new())`](crate::hive::OpenBuilder::with_worker) //! * If you have a custom `Queen`, use -//! [`build_default::<MyQueen>()`](crate::hive::Builder::build_default) if it implements -//! `Default`, otherwise use [`build(MyQueen::new())`](crate::hive::Builder::build) +//! [`with_queen_default::<MyQueen>()`](crate::hive::OpenBuilder::with_queen_default) +//! if it implements `Default`, otherwise use +//! [`with_queen(MyQueen::new())`](crate::hive::OpenBuilder::with_queen) +//! * If instead your queen implements `QueenMut`, use +//! [`with_queen_mut_default::<MyQueenMut>()`](crate::hive::OpenBuilder::with_queen_mut_default) +//! or [`with_queen_mut(MyQueenMut::new())`](crate::hive::OpenBuilder::with_queen_mut) +//! * Use [`with_channel_queues`](crate::hive::OpenBuilder::with_channel_queues) or +//! [`with_workstealing_queues`](crate::hive::OpenBuilder::with_workstealing_queues) +//! to specify the `TaskQueues` type. //! * Note that [`Builder::num_threads()`](crate::hive::Builder::num_threads) must be set //! to a non-zero value, otherwise the built `Hive` will not start any worker threads //! until you call the [`Hive::grow()`](crate::hive::Hive::grow) method. @@ -98,8 +127,7 @@ //! * The methods with the `_send` suffix accept a channel [`Sender`](crate::channel::Sender) and //! send the `Outcome`s to that channel as they are completed //! * The methods with the `_store` suffix store the `Outcome`s in the `Hive`; these may be -//! retrieved later using the [`Hive::take_stored()`](crate::hive::Hive::take_stored) method, -//! using one of the `remove*` methods (which requires +//! retrieved later using one of the `remove*` methods (which requires //! [`OutcomeStore`](crate::hive::OutcomeStore) to be in scope), or by //! using one of the methods on [`Husk`](crate::hive::Husk) after shutting down the `Hive` using //! [`Hive::try_into_husk()`](crate::hive::Hive::try_into_husk). From e6bde1a97c30b16d1dfd9651289edb6088008e2b Mon Sep 17 00:00:00 2001 From: jdidion Date: Mon, 17 Mar 2025 12:55:06 -0700 Subject: [PATCH 63/67] update/fix docs --- Cargo.toml | 3 ++ src/bee/mod.rs | 51 ++++++++++++--------- src/hive/builder/mod.rs | 28 ++++++------ src/hive/inner/builder.rs | 4 ++ src/hive/mod.rs | 96 ++++++++++++++++++++++++++++----------- src/lib.rs | 16 +++---- 6 files changed, 127 insertions(+), 71 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d25a111..56119a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ num = "0.4.3" num_cpus = "1.16.0" parking_lot = "0.12.3" paste = "1.0.15" +simple-mermaid = "0.2.0" thiserror = "1.0.63" # required with the `affinity` feature core_affinity = { version = "0.8.1", optional = true } @@ -37,6 +38,8 @@ itertools = "0.14.0" serial_test = "3.2.0" rstest = "0.22.0" stacker = "0.1.17" +aquamarine = "0.6.0" +simple-mermaid = "0.2.0" [[bench]] name = "perf" diff --git a/src/bee/mod.rs b/src/bee/mod.rs index b1be414..244eb85 100644 --- a/src/bee/mod.rs +++ b/src/bee/mod.rs @@ -3,6 +3,14 @@ //! A [`Hive`](crate::hive::Hive) is populated by bees: //! * The [`Worker`]s process the tasks submitted to the `Hive`. //! * The [`Queen`] creates a new `Worker` for each thread in the `Hive`. //!
* [`QueenMut`] can be used to implement a stateful queen - it must be wrapped in a +//! [`QueenCell`] to make it thread-safe. +//! +//! It is easiest to use the [`prelude`] when implementing your bees: +//! +//! ``` +//! use beekeeper::bee::prelude::*; +//! ``` //! //! # Worker //! @@ -18,7 +26,15 @@ //! //! The `Worker` trait has a single method, [`apply`](crate::bee::Worker::apply), which //! takes an input of type `Input` and a [`Context`] and returns a `Result` containing either an -//! `Output` or an [`ApplyError`]. +//! `Output` or an [`ApplyError`]. Note that `Worker::apply()` takes a `&mut self` parameter, +//! meaning that it can modify its own state. +//! +//! If a fatal error occurs during processing of the task, the worker should return +//! [`ApplyError::Fatal`]. +//! +//! If the task instead fails due to a transient error, the worker should return +//! [`ApplyError::Retryable`]. If the `retry` feature is enabled, then a task that fails with an +//! `ApplyError::Retryable` error will be retried, otherwise the error is converted to `Fatal`. //! //! The `Context` contains information about the task, including: //! * The task ID. Each task submitted to a `Hive` is assigned an ID that is unique within @@ -29,16 +45,14 @@ //! periodically check the cancellation flag by calling //! [`Context::is_cancelled()`](crate::bee::context::Context::is_cancelled). If the cancellation //! flag is set, the worker may terminate early by returning [`ApplyError::Cancelled`]. -//! * If the `retry` feature is enabled, the `Context` also contains the retry -//! [`attempt`](crate::bee::context::Context::attempt), which starts at `0` the first time the task -//! is attempted and increments by `1` for each subsequent retry attempt. -//! -//! If a fatal error occurs during processing of the task, the worker should return -//! [`ApplyError::Fatal`]. +//! * The retry [`attempt`](crate::bee::context::Context::attempt), which starts at `0` the first +//! time the task is attempted. If the `retry` feature is enabled and the task fails with +//! [`ApplyError::Retryable`], this value increments by `1` for each subsequent retry attempt. //! -//! If the task instead fails due to a transient error, the worker should return -//! [`ApplyError::Retryable`]. If the `retry` feature is enabled, then a task that fails with a -//! `ApplyError::Retryable` error will be retried, otherwise the error is converted to `Fatal`. +//! The `Context` also provides the ability to submit new tasks to the `Hive` using the +//! [`submit`](crate::bee::Context::submit) method. The IDs of submitted subtasks are stored in the +//! `Context` and are returned in a field of the [`Outcome`](crate::hive::Outcome) that results +//! from the parent task. //! //! A `Worker` should not panic. However, if it must execute code that may panic, it can do so //! within a closure passed to [`Panic::try_call`](crate::panic::Panic::try_call) and convert an @@ -85,7 +99,11 @@ //! A queen is defined by implementing the [`Queen`] trait. A single `Queen` instance is used to //! create the `Worker` instances for each worker thread in a `Hive`. //! +//! If you need the queen to have mutable state, you can instead implement [`QueenMut`], whose +//! [`create`](crate::bee::QueenMut::create) method takes `&mut self` as a parameter. When +//! creating a `Hive`, the `QueenMut` must be wrapped in a [`QueenCell`] to make it thread-safe. +//! +//!
It is often not necessary to manually implement the `Queen` trait. For example, if your `Worker` //! implements `Default`, then you can use [`DefaultQueen`] implicitly by calling //! [`OpenBuilder::with_worker_default`](crate::hive::OpenBuilder::with_worker_default). Similarly, //! if your `Worker` implements `Clone`, then you can use [`CloneQueen`] //! implicitly by calling [`OpenBuilder::with_worker`](crate::hive::OpenBuilder::with_worker). //! //! A `Queen` should never panic when creating `Worker`s. //! -//! # Implementation Notes -//! -//! It is easiest to use the [`prelude`] when implementing your bees: -//! -//! ``` -//! use beekeeper::bee::prelude::*; -//! ``` -//! -//! Note that both `Queen::create()` and `Worker::apply()` receive `&mut self`, meaning that they -//! can modify their own state. -//! //! The state of a `Hive`'s `Queen` may be interrogated either //! [during](crate::hive::Hive::queen) or [after](crate::hive::Hive::try_into_husk) the //! life of the `Hive`. However, `Worker`s may never be accessed directly. Thus, it is often diff --git a/src/hive/builder/mod.rs index e2dcb3e..62af81d 100644 --- a/src/hive/builder/mod.rs +++ b/src/hive/builder/mod.rs @@ -7,13 +7,7 @@ //! * Fully-typed: builder that has type parameters for the `Worker`, `Queen`, and `TaskQueues` //! types. This is the only builder with a `build` method to create a `Hive`. //! -//! Generic - Queue -//! | / -//! Bee / -//! | / -//! Full -//! -//! All builders implement the `BuilderConfig` trait, which provides methods to set configuration +//! All builders implement the `Builder` trait, which provides methods to set configuration //! parameters. The configuration options available: //! * [`Builder::num_threads`]: number of worker threads that will be spawned by the built `Hive`. //! * [`Builder::with_default_num_threads`] will set `num_threads` to the global default value. @@ -26,19 +20,25 @@ //! [`std::thread`](https://doc.rust-lang.org/stable/std/thread/index.html#stack-size) //! documentation for details on the default stack size. //! +//! The following configuration options are available when the `affinity` feature is enabled: +//! * [`Builder::core_affinity`]: List of CPU core indices to which the threads should be pinned. +//! * [`Builder::with_default_core_affinity`] will set the list to all CPU core indices, though +//! only the first `num_threads` indices will be used. +//! +//! The following configuration options are available when the `local-batch` feature is enabled: +//! * [`Builder::batch_limit`]: Maximum number of tasks that can be queued by a worker. +//! * [`Builder::weight_limit`]: Maximum "weight" of tasks that can be queued by a worker. +//! * [`Builder::with_default_batch_limit`] and [`Builder::with_default_weight_limit`] set the +//! local-batch options to the global defaults, while [`Builder::with_no_local_batching`] +//! disables local-batching. +//! //! The following configuration options are available when the `retry` feature is enabled: //! * [`Builder::max_retries`]: maximum number of times a `Worker` will retry an //! [`ApplyError::Retryable`](crate::bee::ApplyError#Retryable) before giving up. //! * [`Builder::retry_factor`]: [`Duration`](std::time::Duration) factor for exponential backoff //! when retrying an `ApplyError::Retryable` error. //! * [`Builder::with_default_max_retries`] and [`Builder::with_default_retry_factor`] set the -//! retry options to the global defaults, while [`Builder::with_no_retries`] disabled retrying. -//!
The following configuration options are available when the `affinity` feature is enabled: //! * [`Builder::core_affinity`]: List of CPU core indices to which the threads should be pinned. //! * [`Builder::with_default_core_affinity`] will set the list to all CPU core indices, though //! only the first `num_threads` indices will be used. //! -//! +//! retry options to the global defaults, while [`Builder::with_no_retries`] disables retrying. mod bee; mod full; mod open; diff --git a/src/hive/inner/builder.rs index 93d84a9..aff5f0a 100644 --- a/src/hive/inner/builder.rs +++ b/src/hive/inner/builder.rs @@ -7,6 +7,10 @@ pub trait BuilderConfig { } /// Trait that provides `Builder` types with methods for setting configuration parameters. +/// There are multiple `Builder` implementations. See the +/// [module documentation](crate::hive::builder) for more details. +/// +#[doc = simple_mermaid::mermaid!("diagram.mmd")] /// /// This is a sealed trait, meaning it cannot be implemented outside of this crate. pub trait Builder: BuilderConfig + Sized { diff --git a/src/hive/mod.rs index 77cb443..a87429f 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -1,13 +1,15 @@ //! A worker pool implementation. //! -//! A [`Hive`](crate::hive::Hive) has a pool of worker threads that it uses to execute tasks. +//! A [`Hive`](crate::hive::Hive) has a pool of worker threads that it uses to execute tasks. //! //! The `Hive` has a [`Queen`](crate::bee::Queen) of type `Q`, which it uses to create a -//! [`Worker`] of type `W` for each thread it starts in the pool. +//! [`Worker`] of type `Queen::Kind` for each thread it starts in the pool. The `Hive` also has +//! a `TaskQueues` implementation of type `T`, which provides the global and worker thread-local +//! queues for managing tasks. //! //! Each task is submitted to the `Hive` as an input of type `W::Input`, and, optionally, a //! channel where the [`Outcome`] of processing the task will be sent upon completion. To these, -//! the `Hive` adds additional context to create the task. It then adds the task to an internal +//! the `Hive` adds additional context to create the task. It then adds the task to its global //! queue that is shared with all the worker threads. //! //! Each worker thread executes a loop in which it receives a task, evaluates it with its `Worker`, @@ -26,7 +28,7 @@ //! create a `Builder` configured with the global default values (see below). //! //! See the [`builder` module documentation](crate::hive::builder) for more details on the options -//! that may be configured, and the `build*` methods available to create the `Hive`. +//! that may be configured. //! //! Building a `Hive` consumes the `Builder`. To create multiple identical `Hive`s, you can `clone` //! the `Builder`. @@ -93,8 +95,7 @@ //! hive.grow(12); //! //! // increase the number of threads and also provide additional cores for pinning -//! // this requires the `affinity` feature -//! // hive.grow_with_affinity(4, 16..20); +//! hive.grow_with_affinity(4, 16..20); //! ``` //! //! As an application developer depending on `beekeeper`, you must ensure you assign each core //! @@ -106,7 +107,7 @@ //! //! ## Retrying tasks (requires `feature = "retry"`) //! -//! Some types of tasks (e.g., those requirng network I/O operations) may fail transiently but +//! Some types of tasks (e.g., those requiring network I/O operations) may fail transiently but //! could be successful if retried at a later time.
Such retry behavior is supported by the `retry` //! feature and only requires a) configuring the `Builder` by setting //! [`max_retries`](crate::hive::Builder::max_retries) and (optionally) @@ -122,18 +123,18 @@ //! `2^(attempt - 1) * retry_factor`. //! * If a `retry_factor` is not configured, then the task is queued with no delay. //! * When a worker thread becomes available, it first checks the retry queue to see if there is -//! a task to retry before taking a new task from the input channel. +//! a task to retry before taking a new task from the global queue. //! -//! Note that `ApplyError::Retryable` is not feature-gated - a `Worker` can be implemented to be -//! retry-aware but used with a `Hive` for which retry is not enabled, or in an application where -//! the `retry` feature is not enabled. In such cases, `Retryable` errors are automatically -//! converted to `Fatal` errors by the worker thread. +//! Note that `ApplyError::Retryable` and `Context::attempt` are not feature-gated - a `Worker` can +//! be implemented to be retry-aware but used with a `Hive` for which retry is not enabled, or in +//! an application where the `retry` feature is not enabled. In such cases, `Retryable` errors are +//! automatically converted to `Fatal` errors by the worker thread. //! //! ## Batching tasks (requires `feature = "local-batch"`) //! //! The performance of a `Hive` can degrade as the number of worker threads grows and/or the //! average duration of a task shrinks, due to increased contention between worker threads when -//! receiving tasks from the shared input channel. To improve performance, workers can take more +//! receiving tasks from the shared global queue. To improve performance, workers can take more //! than one task each time they access the input channel, and store the extra tasks in a local //! queue. This behavior is activated by enabling the `local-batch` feature. //! @@ -143,11 +144,45 @@ //! [`set_worker_batch_limit`](crate::hive::Hive::set_worker_batch_limit) method for changing the //! batch size of an existing `Hive`. //! +//! ### Task weighting +//! +//! With the `local-batch` feature enabled, it also becomes possible to assign a weight to each +//! input. This is useful for workloads where some tasks may take substantially longer to process +//! than others. Combined with setting a weight limit (using [`Builder::weight_limit`] +//! or [`Hive::set_worker_weight_limit`](crate::hive::Hive::set_worker_weight_limit)), this limits +//! the number of tasks that can be queued by a worker thread based on the minimum of the batch +//! size and the total task weight. +//! +//! A weighted input is an instance of [`Weighted`], where `T` is the worker's input type. +//! Instances of `Weighted` can be created explicitly, or you can convert input values using +//! the methods on `Weighted` or iterators over input values using the methods on the +//! [`WeightedIteratorExt`] and [`WeightedExactSizeIteratorExt`] extension traits. +//! +//! The `Hive` methods for submitting tasks accept both weighted and unweighted input, but weighted +//! inputs are *not* supported with the `local-batch` feature disabled. +//! +//! ``` +//! use beekeeper::hive::prelude::*; +//! # type MyWorker = beekeeper::bee::stock::EchoWorker; +//! +//! let hive = channel_builder(false) +//! .num_threads(4) +//! .batch_limit(10) +//! .weight_limit(10) +//! .with_worker_default::<MyWorker>() +//! .build(); +//! +//! // creates weighted inputs, where each input's weight is the same +//! // as its value, e.g.
`((0,0), (1,1),..,(9,9))` +//! let inputs = (0..10).into_iter().into_identity_weighted(); +//! let outputs = hive.map(inputs).into_outputs(); +//! ``` +//! //! ## Global defaults //! //! The [`hive`](crate::hive) module has functions for setting the global default values for some //! of the `Builder` parameters. These default values are used to pre-configure the `Builder` when -//! using `Builder::default()`. +//! using [`OpenBuilder::default()`]. //! //! The available global defaults are: //! @@ -215,7 +250,7 @@ //! You can create an instance of the enabled outcome channel type using the [`outcome_channel`] //! function. //! -//! `Hive` as several methods (with the `_send` suffix) for submitting tasks whose outcomes will be +//! `Hive` has several methods (with the `_send` suffix) for submitting tasks whose outcomes will be //! delivered to a user-specified channel. Note that, for these methods, the `tx` parameter is of //! type `Borrow>`, which allows you to pass in either a value or a reference. //! Passing a value causes the `Sender` to be dropped after the call; passing a reference allows @@ -240,7 +275,9 @@ //! # Retrieving outcomes //! //! Each task that is successfully submitted to a `Hive` will have a corresponding `Outcome`. -//! [`Outcome`] is similar to `Result`, except that the error variants are enumerated: +//! [`Outcome`] is similar to `Result`, except that the error variants are enumerated. +//! * [`Success`](Outcome::Success): the task completed successfully. The output is provided in +//! the `value` field. //! * [`Failure`](Outcome::Failure): the task failed with an error of type `W::Error`. If possible, //! the input value is also provided. //! * [`Panic`](Outcome::Panic): the `Worker` panicked while processing the task. The panic @@ -253,6 +290,12 @@ //! ID was found. This variant is only used when a list of outcomes is requested, such as when //! using one of the `select_*` methods on an `Outcome` iterator (see below). //! +//! Two additional `Outcome` variants depend on optional feature flags: +//! * [`WeightLimitExceeded`](Outcome::WeightLimitExceeded): depends on the `local-batch` feature; +//! the task's weight exceeded the limit set for the `Hive`. +//! * [`MaxRetriesAttempted`](Outcome::MaxRetriesAttempted): depends on the `retry` feature; the +//! task failed after being retried the maximum number of times. +//! //! An `Outcome` can be converted into a `Result` (using `into()`) or //! [`unwrap`](crate::hive::Outcome::unwrap)ped into an output value of type `W::Output`. //! @@ -264,7 +307,7 @@ //! methods create a dedicated outcome channel to use for each batch of tasks, and thus expect the //! channel receiver to receive exactly the outcomes with the task IDs of the submitted tasks. If, //! somehow, an unexpected `Outcome` is received, it is silently dropped. If any expected outcomes -//! have not been received after the channel sender has disconnected, then those task IDs are' +//! have not been received after the channel sender has disconnected, then those task IDs are //! yielded as `Outcome::Missing` results. //! //! When the [`OutcomeIteratorExt`] trait is in scope, then additional methods become available on @@ -277,7 +320,7 @@ //! //! ## Outcome channels //! -//! Using one of the `*_send` methods with a channel enables the `Hive` to send you `Outcome`s +//! Using one of the `*_send` methods with a channel enables the `Hive` to send `Outcome`s //! asynchronously as they become available. This means that you will likely receive the outcomes //! 
out of order (i.e., not in the same order as the provided inputs). //! @@ -343,18 +386,17 @@ //! Processing can be resumed by calling the [`resume`](crate::hive::Hive::resume) method. //! The [`swarm_unprocessed_send`](crate::hive::Hive::swarm_unprocessed_send) or //! [`swarm_unprocessed_store`](crate::hive::Hive::swarm_unprocessed_store) methods can be used to -//! submit any unprocessed tasks stored in the `Hive` for (re)processing. +//! submit any unprocessed tasks stored in the `Hive` for (re)processing after resuming the `Hive`. //! //! ## Hive poisoning //! //! The internal data structure shared between a `Hive`, its clones, and its worker threads is //! considered thread-safe. However, there is no formal proof that it is incorruptible. A `Hive` //! attempts to detect if it has become corrupted and, if so, sets the `poisoned` flag on the -//! shared data. A poisoned `Hive` will not accept or process any new tasks, and all worker threads -//! will terminate after finishing their current tasks. If a task is submitted to a poisoned `Hive`, -//! it will immediately be converted to an `Unprocessed` outcome and sent/stored. The only thing -//! that can be done with a poisoned `Hive` is to access its stored `Outcome`s or convert it to a -//! `Husk` (see below). +//! shared data. A poisoned `Hive` will not accept or process any new tasks, all worker threads +//! will terminate after finishing their current tasks, and all queued tasks are converted to +//! `Unprocessed` outcomes and sent/stored. The only thing that can be done with a poisoned `Hive` +//! is to access its stored `Outcome`s or convert it to a `Husk` (see below). //! //! # Disposing of a `Hive` //! @@ -363,7 +405,7 @@ //! its shared data is dropped, then the following steps happen: //! * The `Hive` is poisoned to prevent any new tasks from being submitted or queued tasks from //! being processed. -//! * All of the `Hive`s queued tasks are coverted to `Unprocessed` outcomes and either sent to +//! * All of the `Hive`'s queued tasks are converted to `Unprocessed` outcomes and either sent to //! their outcome channel or stored in the `Hive`. //! * If the `Hive` was in a suspended state, it is resumed. This is necessary to unblock the //! worker threads and allow them to terminate. @@ -433,12 +475,12 @@ pub fn outcome_channel() -> (OutcomeSender, OutcomeReceiver) { } pub mod prelude { - #[cfg(feature = "local-batch")] - pub use super::Weighted; pub use super::{ Builder, Hive, Husk, Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore, Poisoned, TaskQueuesBuilder, channel_builder, open_builder, outcome_channel, workstealing_builder, }; + #[cfg(feature = "local-batch")] + pub use super::{Weighted, WeightedExactSizeIteratorExt, WeightedIteratorExt}; } #[cfg(test)] diff --git a/src/lib.rs index 5b12532..bc73071 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,7 +9,7 @@ //! * A [`Builder`](crate::hive::Builder) is used to configure and create a worker pool //! called a [`Hive`](crate::hive::Hive). //! * `Hive` is generic over: -//! * The type of [`Queen`](crate::bee::Queen) which creates `Worker` instances +//! * The type of [`Queen`](crate::bee::Queen), which creates `Worker` instances //! * The type of [`TaskQueues`](crate::hive::TaskQueues), which provides the global and //! worker thread-local queues for managing tasks; there are currently two implementations: //! * Channel: A //! [`crossbeam` channel](https://docs.rs/crossbeam-channel/latest/crossbeam_channel/) //! is used to send tasks from the `Hive` to the worker threads. *This is a good choice //! for most workloads*. //! * Workstealing: A //! [`crossbeam_deque::Injector`](https://docs.rs/crossbeam-deque/latest/crossbeam_deque/struct.Injector.html) //! is used to submit tasks and serves as a global queue. Worker threads each have their //! own local queue and can take tasks either from the global queue or steal from other //! workers' local queues if their own queue is empty. This is a good choice for workloads //! @@ -58,8 +58,8 @@ //! * Tasks may be [`Weighted`](crate::hive::Weighted) to enable balancing unevenly sized //! tasks between worker threads. //!
* `retry`: Tasks that fail due to transient errors (e.g., temporarily unavailable resources) -//! may be retried a set number of times, with an optional, exponentially increasing delay -//! between retries. +//! may be retried up to a set number of times, with an optional, exponentially increasing +//! delay between retries. //! * Several alternative `channel` implementations are supported: //! * [`crossbeam`](https://docs.rs/crossbeam/latest/crossbeam/) //! * [`flume`](https://github.com/zesterer/flume) @@ -78,7 +78,7 @@ //! * Do at least one of the following: //! * Implement [`Default`] for your worker //! * Implement [`Clone`] for your worker -//! * Create a custom worker fatory that implements the [`Queen`](crate::bee::Queen) +//! * Create a custom worker factory that implements the [`Queen`](crate::bee::Queen) //! or [`QueenMut`](crate::bee::QueenMut) trait //! 2. A [`Hive`](crate::hive::Hive) to execute your tasks. Your options are: //! * Use one of the convenience methods in the [`util`] module (see Example 1 below) @@ -325,12 +325,12 @@ //! }) //! .into_bytes(); //! -//! // verify the output - note that `swarm` ensures the outputs are in -//! // the same order as the inputs +//! // verify the output - note that `swarm` ensures the outputs are in the same order +//! // as the inputs //! assert_eq!(output, b"abcdefgh"); //! -//! // shutdown the hive, use the Queen to wait on child processes, and -//! // report errors +//! // shutdown the hive, use the Queen to wait on child processes, and report errors; +//! // the `_outcomes` will be empty as we did not use any `_store` methods //! let (queen, _outcomes) = hive.try_into_husk(false).unwrap().into_parts(); //! let (wait_ok, wait_err): (Vec<_>, Vec<_>) = //! queen.into_inner().wait_for_all().into_iter().partition(Result::is_ok); From b2d05198a50b8c83763025bba25f9712e9165d48 Mon Sep 17 00:00:00 2001 From: jdidion Date: Mon, 17 Mar 2025 12:57:27 -0700 Subject: [PATCH 64/67] add diagram --- src/hive/inner/diagram.mmd | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/hive/inner/diagram.mmd diff --git a/src/hive/inner/diagram.mmd b/src/hive/inner/diagram.mmd new file mode 100644 index 0000000..a9da1e2 --- /dev/null +++ b/src/hive/inner/diagram.mmd @@ -0,0 +1,5 @@ +graph TD; + Generic-->Queue + Generic-->Bee + Bee-->Full + Queue-->Full \ No newline at end of file From 674ca53a485474f9017097080419ae044c61b04e Mon Sep 17 00:00:00 2001 From: jdidion Date: Mon, 17 Mar 2025 14:03:09 -0700 Subject: [PATCH 65/67] combine weighted iterator extension traits --- src/hive/mod.rs | 16 +++++------ src/hive/weighted.rs | 63 ++++++++++++++++---------------------------- 2 files changed, 31 insertions(+), 48 deletions(-) diff --git a/src/hive/mod.rs b/src/hive/mod.rs index a87429f..0a1bde3 100644 --- a/src/hive/mod.rs +++ b/src/hive/mod.rs @@ -156,7 +156,7 @@ //! A weighted input is an instance of [`Weighted`], where `T` is the worker's input type. //! Instances of `Weighted` can be created explicitly, or you can convert input values using //! the methods on `Weighted` or iterators over input values using the methods on the -//! [`WeightedIteratorExt`] and [`WeightedExactSizeIteratorExt`] extension traits. +//! [`WeightedIteratorExt`] extension trait. //! //! The `Hive` methods for submitting tasks accept both weighted and unweighted input, but weighted //! inputs are *not* supported with the `local-batch` feature disabled. 
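To make the consolidated `WeightedIteratorExt` trait concrete, here is a small sketch based on the renamed `*_exact` methods and the tests in this patch (type and method names are taken from this diff; treat the exact signatures as assumptions):

```rust
use beekeeper::hive::prelude::*;

fn main() {
    // Convert an ExactSizeIterator of (value, weight) pairs into Weighted
    // items, mirroring the `test_into_weighted_exact` test in this series.
    let weighted: Vec<Weighted<u32>> = (0..10u32)
        .map(|i| (i, i)) // weight each input by its own value
        .into_weighted_exact()
        .collect();

    for w in weighted {
        // into_parts() splits a Weighted into (value, weight)
        let (value, weight) = w.into_parts();
        assert_eq!(weight, value);
    }
}
```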
@@ -452,7 +452,7 @@ pub use self::inner::{ }; pub use self::outcome::{Outcome, OutcomeBatch, OutcomeIteratorExt, OutcomeStore}; #[cfg(feature = "local-batch")] -pub use self::weighted::{Weighted, WeightedExactSizeIteratorExt, WeightedIteratorExt}; +pub use self::weighted::{Weighted, WeightedIteratorExt}; use self::context::HiveLocalContext; use self::inner::{Config, Shared, Task, WorkerQueues}; @@ -480,7 +480,7 @@ pub mod prelude { TaskQueuesBuilder, channel_builder, open_builder, outcome_channel, workstealing_builder, }; #[cfg(feature = "local-batch")] - pub use super::{Weighted, WeightedExactSizeIteratorExt, WeightedIteratorExt}; + pub use super::{Weighted, WeightedIteratorExt}; } #[cfg(test)] @@ -2575,7 +2575,7 @@ mod weighted_map_tests { mod weighted_swarm_tests { use crate::bee::stock::{EchoWorker, Thunk, ThunkWorker}; use crate::hive::{ - Builder, Outcome, TaskQueuesBuilder, WeightedExactSizeIteratorExt, channel_builder, + Builder, Outcome, TaskQueuesBuilder, WeightedIteratorExt, channel_builder, workstealing_builder, }; use rstest::*; @@ -2604,7 +2604,7 @@ mod weighted_swarm_tests { }) }) .map(|thunk| (thunk, 0)) - .into_weighted(); + .into_weighted_exact(); let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); assert_eq!(outputs, (0..10).collect::>()) } @@ -2630,7 +2630,7 @@ mod weighted_swarm_tests { i }) }) - .into_default_weighted(); + .into_default_weighted_exact(); let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); assert_eq!(outputs, (0..10).collect::>()) } @@ -2656,7 +2656,7 @@ mod weighted_swarm_tests { i }) }) - .into_const_weighted(0); + .into_const_weighted_exact(0); let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); assert_eq!(outputs, (0..10).collect::>()) } @@ -2675,7 +2675,7 @@ mod weighted_swarm_tests { .num_threads(NUM_THREADS) .batch_limit(BATCH_LIMIT) .build(); - let inputs = (0..10u8).into_identity_weighted(); + let inputs = (0..10u8).into_identity_weighted_exact(); let outputs: Vec<_> = hive.swarm(inputs).map(Outcome::unwrap).collect(); assert_eq!(outputs, (0..10).collect::>()) } diff --git a/src/hive/weighted.rs b/src/hive/weighted.rs index 317be43..fed4bef 100644 --- a/src/hive/weighted.rs +++ b/src/hive/weighted.rs @@ -115,14 +115,8 @@ pub trait WeightedIteratorExt: IntoIterator + Sized { Weighted::new(item, weight) }) } -} - -impl WeightedIteratorExt for T {} -/// Extends `IntoIterator` to add methods to convert any iterator into an iterator over `Weighted` -/// items. 
-pub trait WeightedExactSizeIteratorExt: IntoIterator + Sized { - fn into_weighted(self) -> impl ExactSizeIterator> + fn into_weighted_exact(self) -> impl ExactSizeIterator> where Self: IntoIterator, Self::IntoIter: ExactSizeIterator + 'static, @@ -131,14 +125,17 @@ pub trait WeightedExactSizeIteratorExt: IntoIterator + Sized { .map(|(value, weight)| Weighted::new(value, weight)) } - fn into_default_weighted(self) -> impl ExactSizeIterator> + fn into_default_weighted_exact(self) -> impl ExactSizeIterator> where Self::IntoIter: ExactSizeIterator + 'static, { self.into_iter().map(Into::into) } - fn into_const_weighted(self, weight: u32) -> impl ExactSizeIterator> + fn into_const_weighted_exact( + self, + weight: u32, + ) -> impl ExactSizeIterator> where Self::IntoIter: ExactSizeIterator + 'static, { @@ -146,7 +143,7 @@ pub trait WeightedExactSizeIteratorExt: IntoIterator + Sized { .map(move |item| Weighted::new(item, weight)) } - fn into_identity_weighted(self) -> impl ExactSizeIterator> + fn into_identity_weighted_exact(self) -> impl ExactSizeIterator> where Self::Item: ToPrimitive + Clone, Self::IntoIter: ExactSizeIterator + 'static, @@ -154,7 +151,10 @@ pub trait WeightedExactSizeIteratorExt: IntoIterator + Sized { self.into_iter().map(Weighted::from_identity) } - fn into_weighted_with(self, f: F) -> impl ExactSizeIterator> + fn into_weighted_exact_with( + self, + f: F, + ) -> impl ExactSizeIterator> where Self::IntoIter: ExactSizeIterator + 'static, F: Fn(&Self::Item) -> u32, @@ -166,17 +166,12 @@ pub trait WeightedExactSizeIteratorExt: IntoIterator + Sized { } } -impl WeightedExactSizeIteratorExt for T -where - T: IntoIterator, - T::IntoIter: ExactSizeIterator, -{ -} +impl WeightedIteratorExt for T {} #[cfg(test)] #[cfg_attr(coverage_nightly, coverage(off))] mod tests { - use super::Weighted; + use super::*; #[test] fn test_new() { @@ -214,12 +209,6 @@ mod tests { assert_eq!(weighted.weight(), 10); assert_eq!(weighted.into_parts(), (42, 10)); } -} - -#[cfg(test)] -#[cfg_attr(coverage_nightly, coverage(off))] -mod iter_tests { - use super::WeightedIteratorExt; #[test] fn test_into_weighted() { @@ -263,46 +252,40 @@ mod iter_tests { .into_weighted_with(|i| i * 2) .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value * 2)); } -} - -#[cfg(test)] -#[cfg_attr(coverage_nightly, coverage(off))] -mod exact_iter_test { - use super::WeightedExactSizeIteratorExt; #[test] - fn test_into_weighted() { + fn test_into_weighted_exact() { (0..10) .map(|i| (i, i)) - .into_weighted() + .into_weighted_exact() .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value)); } #[test] - fn test_into_default_weighted() { + fn test_into_default_weighted_exact() { (0..10) - .into_default_weighted() + .into_default_weighted_exact() .for_each(|weighted| assert_eq!(weighted.weight(), 0)); } #[test] - fn test_into_identity_weighted() { + fn test_into_identity_weighted_exact() { (0..10) - .into_identity_weighted() + .into_identity_weighted_exact() .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value)); } #[test] - fn test_into_const_weighted() { + fn test_into_const_weighted_exact() { (0..10) - .into_const_weighted(5) + .into_const_weighted_exact(5) .for_each(|weighted| assert_eq!(weighted.weight(), 5)); } #[test] - fn test_into_weighted_with() { + fn test_into_weighted_exact_with() { (0..10) - .into_weighted_with(|i| i * 2) + .into_weighted_exact_with(|i| i * 2) .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value * 2)); } } From 5d6df2e388e8f66d024fed6ccae6e83f75318357 Mon 
From 5d6df2e388e8f66d024fed6ccae6e83f75318357 Mon Sep 17 00:00:00 2001
From: jdidion
Date: Mon, 17 Mar 2025 14:12:22 -0700
Subject: [PATCH 66/67] combine weighted iterator extension traits

---
 src/hive/weighted.rs | 33 ++++++++++++++++++++++++++++-----
 1 file changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/hive/weighted.rs b/src/hive/weighted.rs
index fed4bef..3931535 100644
--- a/src/hive/weighted.rs
+++ b/src/hive/weighted.rs
@@ -72,23 +72,31 @@ impl From<(T, P)> for Weighted {
 /// Extends `IntoIterator` to add methods to convert any iterator into an iterator over `Weighted`
 /// items.
 pub trait WeightedIteratorExt: IntoIterator + Sized {
-    fn into_weighted(self) -> impl Iterator>
+    /// Converts this iterator over (T, P) items into an iterator over `Weighted` items with
+    /// weights set to `P::into_u32()`.
+    fn into_weighted(self) -> impl Iterator>
     where
-        Self: IntoIterator,
+        P: ToPrimitive,
+        Self: IntoIterator,
     {
         self.into_iter()
             .map(|(value, weight)| Weighted::new(value, weight))
     }
 
+    /// Converts this iterator into an iterator over `Weighted` with weights set to 0.
     fn into_default_weighted(self) -> impl Iterator> {
         self.into_iter().map(Into::into)
     }
 
+    /// Converts this iterator into an iterator over `Weighted` with weights set to
+    /// `weight`.
     fn into_const_weighted(self, weight: u32) -> impl Iterator> {
         self.into_iter()
             .map(move |item| Weighted::new(item, weight))
     }
 
+    /// Converts this iterator into an iterator over `Weighted` with weights set to
+    /// `item.clone().into_u32()`.
     fn into_identity_weighted(self) -> impl Iterator>
     where
         Self::Item: ToPrimitive + Clone,
@@ -96,16 +104,21 @@ pub trait WeightedIteratorExt: IntoIterator + Sized {
         self.into_iter().map(Weighted::from_identity)
     }
 
-    fn into_weighted_zip(self, weights: W) -> impl Iterator>
+    /// Zips this iterator with `weights` and converts each tuple into a `Weighted`
+    /// with the weight set to the corresponding value from `weights`.
+    fn into_weighted_zip(self, weights: W) -> impl Iterator>
     where
-        W: IntoIterator,
+        P: ToPrimitive + Clone + Default,
+        W: IntoIterator,
         W::IntoIter: 'static,
     {
         self.into_iter()
-            .zip(weights.into_iter().chain(std::iter::repeat(0)))
+            .zip(weights.into_iter().chain(std::iter::repeat(P::default())))
             .map(Into::into)
     }
 
+    /// Converts this iterator into an iterator over `Weighted` with weights set to
+    /// the result of calling `f` on each item.
     fn into_weighted_with(self, f: F) -> impl Iterator>
     where
         F: Fn(&Self::Item) -> u32,
@@ -116,6 +129,8 @@ pub trait WeightedIteratorExt: IntoIterator + Sized {
         })
     }
 
+    /// Converts this `ExactSizeIterator` over (T, P) items into an `ExactSizeIterator` over
+    /// `Weighted` items with weights set to `P::into_u32()`.
     fn into_weighted_exact(self) -> impl ExactSizeIterator>
     where
         Self: IntoIterator,
@@ -125,6 +140,8 @@ pub trait WeightedIteratorExt: IntoIterator + Sized {
             .map(|(value, weight)| Weighted::new(value, weight))
     }
 
+    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted`
+    /// with weights set to 0.
     fn into_default_weighted_exact(self) -> impl ExactSizeIterator>
     where
         Self::IntoIter: ExactSizeIterator + 'static,
     {
         self.into_iter().map(Into::into)
     }
 
+    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted`
+    /// with weights set to `weight`.
     fn into_const_weighted_exact(
         self,
         weight: u32,
@@ -143,6 +162,8 @@ pub trait WeightedIteratorExt: IntoIterator + Sized {
             .map(move |item| Weighted::new(item, weight))
     }
 
+    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted`
+    /// with weights set to `item.clone().into_u32()`.
     fn into_identity_weighted_exact(self) -> impl ExactSizeIterator>
     where
         Self::Item: ToPrimitive + Clone,
@@ -151,6 +172,8 @@ pub trait WeightedIteratorExt: IntoIterator + Sized {
         self.into_iter().map(Weighted::from_identity)
     }
 
+    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted`
+    /// with weights set to the result of calling `f` on each item.
     fn into_weighted_exact_with(
         self,
         f: F,
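With the two traits combined and weights generalized to any `P: ToPrimitive`, the pair-based and zip-based conversions can accept weight types other than `u32`. A rough sketch under those assumptions; the `demo` function and its literal inputs are illustrative only, the `local-batch` feature is assumed, and `Weighted::new` is assumed to accept any `ToPrimitive` weight as the unchanged method bodies in the patch imply:

    use beekeeper::hive::prelude::*;

    fn demo() {
        // (value, weight) pairs where the weight is a u8 rather than a u32.
        let from_pairs: Vec<_> = vec![("a", 1u8), ("b", 2u8)]
            .into_weighted()
            .collect();
        assert_eq!(from_pairs.len(), 2);

        // Zip values with a separate weight iterator; per the patch above,
        // missing weights now fall back to P::default() instead of 0.
        let zipped: Vec<_> = (0..5u32)
            .into_weighted_zip(vec![3u8, 1u8])
            .collect();
        assert_eq!(zipped.len(), 5);
    }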
From a20361fafa19d87e3b685e1f33297398ae3b636f Mon Sep 17 00:00:00 2001
From: jdidion
Date: Mon, 17 Mar 2025 14:37:40 -0700
Subject: [PATCH 67/67] update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7d5356c..c5c30ed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,7 @@ The general theme of this release is performance improvement by eliminating thre
  * Switched to storing `Outcome`s in the hive using a data structure that does not require locking when inserting, which should reduce thread contention when using `*_store` operations.
  * Switched to using `crossbeam_channel` for the task input channel in `ChannelTaskQueues`. These are multi-producer, multi-consumer channels (mpmc; as opposed to `std::mpsc`, which is single-consumer), which means it is no longer necessary for worker threads to acquire a Mutex lock on the channel receiver when getting tasks.
  * Added the `beekeeper::hive::mock` module, which has a `MockTaskRunner` for `apply`ing a worker in a mock context. This is useful for testing your `Worker`.
+ * Updated to `2024` edition and Rust version `1.85`.
 
 ## 0.2.1
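The `crossbeam_channel` bullet above is the key contention fix in this release: mpmc receivers implement `Clone`, so every worker thread can own its own receiver handle instead of locking a shared `std::sync::mpsc::Receiver`. The following standalone illustration of that pattern is not beekeeper's actual worker loop; it assumes only that the `crossbeam-channel` crate is available as a dependency:

    use std::thread;

    fn main() {
        let (tx, rx) = crossbeam_channel::unbounded::<u32>();
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let rx = rx.clone(); // mpmc: no Mutex around the receiver
                thread::spawn(move || {
                    // recv() returns Err once every sender has been dropped,
                    // which is what lets the worker loop terminate cleanly.
                    while let Ok(task) = rx.recv() {
                        let _ = task; // process the task here
                    }
                })
            })
            .collect();
        for task in 0..16 {
            tx.send(task).unwrap();
        }
        drop(tx); // disconnect the channel so workers exit their loops
        for worker in workers {
            worker.join().unwrap();
        }
    }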