diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index d78935e7243..3bf9919404f 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -159,6 +159,11 @@ cfg_rt! { CONTEXT.try_with(|ctx| ctx.current_task_id.get()).unwrap_or(None) } + #[cfg(tokio_unstable)] + pub(crate) fn worker_index() -> Option { + with_scheduler(|ctx| ctx.and_then(|c| c.worker_index())) + } + #[track_caller] pub(crate) fn defer(waker: &Waker) { with_scheduler(|maybe_scheduler| { diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 77cf183dc56..f1cc39b3d92 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -569,6 +569,40 @@ cfg_rt! { mod local_runtime; pub use local_runtime::{LocalRuntime, LocalOptions}; + + /// Returns the index of the current worker thread, if called from a + /// runtime worker thread. + /// + /// The returned value is a 0-based index matching the worker indices + /// used by [`RuntimeMetrics`] methods such as + /// [`worker_total_busy_duration`](RuntimeMetrics::worker_total_busy_duration). + /// + /// Returns `None` when called from outside a runtime worker thread + /// (for example, from a blocking thread or a non-Tokio thread). On the + /// multi-thread runtime, the thread that calls [`Runtime::block_on`] is + /// not a worker thread, so this also returns `None` there. + /// + /// For the current-thread runtime, this always returns `Some(0)` + /// (including inside `block_on`, since the calling thread *is* the + /// worker thread). + /// + /// # Examples + /// + /// ``` + /// # #[cfg(not(target_family = "wasm"))] + /// # { + /// #[tokio::main(flavor = "multi_thread", worker_threads = 4)] + /// async fn main() { + /// let index = tokio::spawn(async { + /// tokio::runtime::worker_index() + /// }).await.unwrap(); + /// println!("Task ran on worker {:?}", index); + /// } + /// # } + /// ``` + pub fn worker_index() -> Option { + context::worker_index() + } } cfg_taskdump! { diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index 3f142120d33..f991e8abda7 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -293,6 +293,15 @@ cfg_rt! { match_flavor!(self, Context(context) => context.defer(waker)); } + #[cfg(tokio_unstable)] + pub(crate) fn worker_index(&self) -> Option { + match self { + Context::CurrentThread(_) => Some(0), + #[cfg(feature = "rt-multi-thread")] + Context::MultiThread(context) => Some(context.worker_index()), + } + } + #[cfg(all(tokio_unstable, feature = "time", feature = "rt-multi-thread"))] pub(crate) fn with_time_temp_local_context(&self, f: F) -> R where diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index f48e6ba5271..72bdc2bd31c 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -1006,6 +1006,11 @@ impl Context { None => f(None), }) } + + #[cfg(tokio_unstable)] + pub(crate) fn worker_index(&self) -> usize { + self.worker.index + } } impl Core { diff --git a/tokio/tests/rt_worker_index.rs b/tokio/tests/rt_worker_index.rs new file mode 100644 index 00000000000..873c5247d4a --- /dev/null +++ b/tokio/tests/rt_worker_index.rs @@ -0,0 +1,94 @@ +#![warn(rust_2018_idioms)] +#![cfg(all( + feature = "full", + tokio_unstable, + not(target_os = "wasi"), + target_has_atomic = "64" +))] + +use tokio::runtime::{self, Runtime}; + +#[test] +fn worker_index_multi_thread() { + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let index = tokio::task::spawn(async { runtime::worker_index() }) + .await + .unwrap(); + let num_workers = rt.metrics().num_workers(); + let index = index.expect("should be Some on worker thread"); + assert!( + index < num_workers, + "worker_index {index} >= num_workers {num_workers}" + ); + }); +} + +#[test] +fn worker_index_current_thread() { + let rt = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + rt.block_on(async { + let index = runtime::worker_index(); + assert_eq!(index, Some(0)); + }); +} + +#[test] +fn worker_index_outside_runtime() { + assert_eq!(runtime::worker_index(), None); +} + +#[test] +fn worker_index_matches_metrics_worker_thread_id() { + let rt = runtime::Builder::new_multi_thread() + .worker_threads(4) + .enable_all() + .build() + .unwrap(); + let metrics = rt.metrics(); + + rt.block_on(async { + // Spawn a task and verify the worker_index matches the metrics index + tokio::task::spawn(async move { + let index = runtime::worker_index().expect("should be on worker thread"); + let current_thread = std::thread::current().id(); + let metrics_thread = metrics.worker_thread_id(index); + assert_eq!( + metrics_thread, + Some(current_thread), + "worker_index() returned {index} but metrics.worker_thread_id({index}) \ + does not match current thread" + ); + }) + .await + .unwrap(); + }); +} + +#[test] +fn worker_index_from_spawn_blocking() { + let rt = Runtime::new().unwrap(); + rt.block_on(async { + let index = tokio::task::spawn_blocking(runtime::worker_index) + .await + .unwrap(); + assert_eq!( + index, None, + "spawn_blocking should not be on a worker thread" + ); + }); +} + +#[test] +fn worker_index_block_on_multi_thread() { + let rt = Runtime::new().unwrap(); + // block_on runs on the calling thread, not a worker thread + let index = rt.block_on(async { runtime::worker_index() }); + assert_eq!( + index, None, + "block_on thread is not a worker thread on multi-thread runtime" + ); +}