diff --git a/tokio/src/runtime/task/trace/mod.rs b/tokio/src/runtime/task/trace/mod.rs
index 5455e1133db..199b5d0eef7 100644
--- a/tokio/src/runtime/task/trace/mod.rs
+++ b/tokio/src/runtime/task/trace/mod.rs
@@ -34,7 +34,8 @@ pub(crate) struct Context {
/// The function that is invoked at each leaf future inside of Tokio
///
/// For example, within tokio::time:sleep, sockets. etc.
- trace_leaf_fn: Cell>,
+ #[allow(clippy::type_complexity)]
+ trace_leaf_fn: Cell >>,
}
/// A [`Frame`] in an intrusive, doubly-linked tree of [`Frame`]s.
@@ -114,19 +115,38 @@ impl Context {
}
}
+ /// Calls the provided closure if we are being traced.
fn try_with_current_trace_leaf_fn(f: F) -> Option
where
- F: FnOnce(&Cell>) -> R,
+ F: for<'a> FnOnce(&'a mut dyn FnMut(&TraceMeta)) -> R,
{
+ let mut ret = None;
+
+ let inner = |context: &Context| {
+ if let Some(mut trace_leaf_fn) = context.trace_leaf_fn.get() {
+ context.trace_leaf_fn.set(None);
+ let _restore = defer(move || {
+ context.trace_leaf_fn.set(Some(trace_leaf_fn));
+ });
+
+ // SAFETY: The trace leaf fn is valid for the duration in which it's stored in the
+ // context. Furthermore, re-entrant calls are not possible since we store `None` for
+ // the duration in which we hold a mutable reference, so access is exclusive for that
+ // duration.
+ ret = Some(f(unsafe { trace_leaf_fn.as_mut() }));
+ }
+ };
+
// SAFETY: This call can only access the trace_leaf_fn field, so it cannot
// break the trace frame linked list.
- unsafe { Self::try_with_current(|context| f(&context.trace_leaf_fn)) }
+ unsafe { Self::try_with_current(inner) };
+
+ ret
}
/// Produces `true` if the current task is being traced; otherwise false.
pub(crate) fn is_tracing() -> bool {
- Self::try_with_current_trace_leaf_fn(|maybe_trace_leaf| maybe_trace_leaf.get().is_some())
- .unwrap_or(false)
+ Self::try_with_current_trace_leaf_fn(|_| ()).is_some()
}
}
@@ -189,20 +209,46 @@ pub struct TraceMeta {
/// assert!(count > 0);
/// # }
/// ```
-pub fn trace_with(f: F, trace_leaf: fn(&TraceMeta)) -> R
+pub fn trace_with(f: FN, mut trace_leaf: FT) -> R
where
- F: FnOnce() -> R,
+ FN: FnOnce() -> R,
+ FT: FnMut(&TraceMeta),
{
- // store our new trace_leaf function
- let previous =
- Context::try_with_current_trace_leaf_fn(|current| current.replace(Some(trace_leaf)));
+ let trace_leaf_dyn = (&mut trace_leaf) as &mut (dyn FnMut(&TraceMeta) + '_);
+ // SAFETY: The raw pointer is removed from the thread local before `trace_leaf` is dropped, so
+ // this transmute cannot lead to the violation of any lifetime requirements.
+ let trace_leaf_dyn = unsafe {
+ std::mem::transmute::<
+ *mut (dyn FnMut(&TraceMeta) + '_),
+ *mut (dyn FnMut(&TraceMeta) + 'static),
+ >(trace_leaf_dyn)
+ };
+ // SAFETY: Pointer comes from reference, so not null.
+ let trace_leaf_dyn = unsafe { NonNull::new_unchecked(trace_leaf_dyn) };
+
+ let mut old_trace_leaf_fn = None;
+
+ // Even if this access fails, that's okay. In that case, we still call the closure without
+ // actually performing any tracing.
+ //
+ // SAFETY: This call can only access the trace_leaf_fn field, so it cannot
+ // break the trace frame linked list.
+ unsafe {
+ Context::try_with_current(|ctx| {
+ old_trace_leaf_fn = ctx.trace_leaf_fn.replace(Some(trace_leaf_dyn));
+ })
+ };
- // restore previous on drop. This is ensures state remains consistent
- // even if the trace_leaf function panics
let _restore = defer(move || {
- if let Some(previous) = previous {
- Context::try_with_current_trace_leaf_fn(|current| current.set(previous));
- }
+ // This ensures that `trace_leaf_fn` cannot be accessed after this call returns.
+ //
+ // SAFETY: This call can only access the trace_leaf_fn field, so it cannot
+ // break the trace frame linked list.
+ unsafe {
+ Context::try_with_current(|ctx| {
+ ctx.trace_leaf_fn.set(old_trace_leaf_fn);
+ })
+ };
});
f()
@@ -249,13 +295,14 @@ impl Trace {
// internal implementation details of this crate).
#[inline(never)]
pub(crate) fn trace_leaf(cx: &mut task::Context<'_>) -> Poll<()> {
- let trace_leaf_fn = Context::try_with_current_trace_leaf_fn(|cell| cell.get()).flatten();
- if let Some(trace_leaf_fn) = trace_leaf_fn {
+ let root_addr = Context::current_frame_addr();
+
+ let ret = Context::try_with_current_trace_leaf_fn(|leaf_fn| {
let meta = TraceMeta {
- root_addr: Context::current_frame_addr(),
+ root_addr,
trace_leaf_addr: trace_leaf as *const c_void,
};
- trace_leaf_fn(&meta);
+ leaf_fn(&meta);
// Use the same logic that `yield_now` uses to send out wakeups after
// the task yields.
@@ -268,10 +315,11 @@ pub(crate) fn trace_leaf(cx: &mut task::Context<'_>) -> Poll<()> {
}
}
});
+ });
- Poll::Pending
- } else {
- Poll::Ready(())
+ match ret {
+ Some(()) => Poll::Pending,
+ None => Poll::Ready(()),
}
}
diff --git a/tokio/src/runtime/task/trace/trace_impl.rs b/tokio/src/runtime/task/trace/trace_impl.rs
index 5bf90792568..3197cca8fe7 100644
--- a/tokio/src/runtime/task/trace/trace_impl.rs
+++ b/tokio/src/runtime/task/trace/trace_impl.rs
@@ -2,73 +2,42 @@
//!
//! This implementation may eventually be extracted into a separate `tokio-taskdump` crate.
-use std::{cell::Cell, ptr};
+use std::ptr;
use crate::runtime::task::trace::{trace_with, Trace, TraceMeta};
-use super::defer;
-
-/// Thread local state used to communicate between calling the trace and the interior `trace_leaf` function
-struct TraceContext {
- collector: Cell>,
-}
-
-thread_local! {
- static TRACE_CONTEXT: TraceContext = const {
- TraceContext {
- collector: Cell::new(None),
- }
- };
-}
-
/// Capture using the default `backtrace::trace`-based implementation.
#[inline(never)]
pub(super) fn capture(f: F) -> (R, Trace)
where
F: FnOnce() -> R,
{
- let collector = Trace::empty();
-
- let previous = TRACE_CONTEXT.with(|state| state.collector.replace(Some(collector)));
+ let mut trace = Trace::empty();
- // restore previous collector on drop even if the callback panics
- let _restore = defer(move || {
- TRACE_CONTEXT.with(|state| state.collector.set(previous));
- });
+ let result = trace_with(f, |meta| trace_leaf(meta, &mut trace));
- let result = trace_with(f, trace_leaf);
-
- // take the collector before _restore runs
- let collector = TRACE_CONTEXT.with(|state| state.collector.take()).unwrap();
-
- (result, collector)
+ (result, trace)
}
-/// Capture a backtrace via `backtrace::trace` and collect it into `STATE`
-#[inline(never)]
-pub(crate) fn trace_leaf(meta: &TraceMeta) {
- TRACE_CONTEXT.with(|state| {
- if let Some(mut collector) = state.collector.take() {
- let mut frames: Vec = vec![];
- let mut above_leaf = false;
-
- if let Some(root_addr) = meta.root_addr {
- backtrace::trace(|frame| {
- let below_root = !ptr::eq(frame.symbol_address(), root_addr);
+/// Capture a backtrace via `backtrace::trace` and collect it into `trace`.
+pub(crate) fn trace_leaf(meta: &TraceMeta, trace: &mut Trace) {
+ let mut frames: Vec = vec![];
+ let mut above_leaf = false;
- if above_leaf && below_root {
- frames.push(frame.to_owned().into());
- }
+ if let Some(root_addr) = meta.root_addr {
+ backtrace::trace(|frame| {
+ let below_root = !ptr::eq(frame.symbol_address(), root_addr);
- if ptr::eq(frame.symbol_address(), meta.trace_leaf_addr) {
- above_leaf = true;
- }
+ if above_leaf && below_root {
+ frames.push(frame.to_owned().into());
+ }
- below_root
- });
+ if ptr::eq(frame.symbol_address(), meta.trace_leaf_addr) {
+ above_leaf = true;
}
- collector.push_backtrace(frames);
- state.collector.set(Some(collector));
- }
- });
+
+ below_root
+ });
+ }
+ trace.push_backtrace(frames);
}
diff --git a/tokio/tests/task_trace_self.rs b/tokio/tests/task_trace_self.rs
index 1fe6056686b..e3e0c479132 100644
--- a/tokio/tests/task_trace_self.rs
+++ b/tokio/tests/task_trace_self.rs
@@ -109,48 +109,41 @@ async fn task_trace_self() {
/// Collect frames between `trace_leaf_for_test` and `root_addr` using
/// `backtrace::trace`, resolve them, and store pretty-printed symbol names
-/// (with compiler hashes stripped) into `TRACE_WITH_LOG`.
+/// (with compiler hashes stripped) into `logs`.
#[inline(never)]
-fn trace_leaf_for_test(meta: &TraceMeta) {
- TRACE_WITH_LOG.with(|log| {
- let mut frames: Vec = vec![];
- let mut above_leaf = false;
+fn trace_leaf_for_test(meta: &TraceMeta, log: &mut Vec>) {
+ let mut frames: Vec = vec![];
+ let mut above_leaf = false;
- if let Some(root_addr) = meta.root_addr {
- backtrace::trace(|frame| {
- let below_root = !ptr::eq(frame.symbol_address(), root_addr);
+ if let Some(root_addr) = meta.root_addr {
+ backtrace::trace(|frame| {
+ let below_root = !ptr::eq(frame.symbol_address(), root_addr);
- if above_leaf && below_root {
- frames.push(frame.to_owned().into());
- }
+ if above_leaf && below_root {
+ frames.push(frame.to_owned().into());
+ }
- if ptr::eq(frame.symbol_address(), meta.trace_leaf_addr) {
- above_leaf = true;
- }
+ if ptr::eq(frame.symbol_address(), meta.trace_leaf_addr) {
+ above_leaf = true;
+ }
- below_root
- });
- }
+ below_root
+ });
+ }
- // Resolve frames into human-readable symbol names with hashes stripped.
- let mut bt = backtrace::Backtrace::from(frames);
- bt.resolve();
- let mut names = vec![];
- for frame in bt.frames() {
- for symbol in frame.symbols() {
- if let Some(name) = symbol.name() {
- names.push(strip_symbol_hash(&format!("{name}")).to_owned());
- }
+ // Resolve frames into human-readable symbol names with hashes stripped.
+ let mut bt = backtrace::Backtrace::from(frames);
+ bt.resolve();
+ let mut names = vec![];
+ for frame in bt.frames() {
+ for symbol in frame.symbols() {
+ if let Some(name) = symbol.name() {
+ names.push(strip_symbol_hash(&format!("{name}")).to_owned());
}
}
+ }
- log.borrow_mut().push(names);
- });
-}
-
-thread_local! {
- static TRACE_WITH_LOG: std::cell::RefCell>> =
- const { std::cell::RefCell::new(vec![]) };
+ log.push(names);
}
/// Strip the trailing `::h` hash that rustc appends to symbol names.
@@ -197,13 +190,17 @@ impl Future for TaskDump {
Poll::Pending => {}
};
+ let mut logs = Vec::new();
+
// Tracing poll with a noop waker. If the future is at a yield
// point, trace_leaf fires our callback and returns Pending. We discard
// the result — this poll is purely for capturing the backtrace.
let noop = futures::task::noop_waker();
let mut noop_cx = Context::from_waker(&noop);
- let logs = this.logs.clone();
- let trace_poll = trace_with(|| this.f.as_mut().poll(&mut noop_cx), trace_leaf_for_test);
+ let trace_poll = trace_with(
+ || this.f.as_mut().poll(&mut noop_cx),
+ |meta| trace_leaf_for_test(meta, &mut logs),
+ );
// trace should always produce poll pending
assert!(
matches!(trace_poll, Poll::Pending),
@@ -211,11 +208,7 @@ impl Future for TaskDump {
);
// Drain any frames captured by trace_leaf_for_test into our log.
- TRACE_WITH_LOG.with(|tl| {
- let mut tl = tl.borrow_mut();
- let mut dest = logs.lock().unwrap();
- dest.append(&mut tl);
- });
+ this.logs.lock().unwrap().extend(logs);
Poll::Pending
}
}