From 8963b5d4cf382ce3ebb332f7ddc816dd5004a35f Mon Sep 17 00:00:00 2001 From: noah Date: Mon, 3 Feb 2025 21:41:42 -0600 Subject: [PATCH] rt: overhaul task hooks This change overhauls the entire task hooks system so that users can propagate arbitrary information between task hook invocations and pass context data between the hook "harnesses" for parent and child tasks at time of spawn. This is intended to be significantly more extensible and long-term maintainable than the current task hooks system, and should ultimately be much easier to stabilize. --- tokio/src/lib.rs | 3 - tokio/src/runtime/blocking/pool.rs | 5 + tokio/src/runtime/blocking/schedule.rs | 24 +- tokio/src/runtime/builder.rs | 237 +-------- tokio/src/runtime/config.rs | 19 +- tokio/src/runtime/context.rs | 68 ++- tokio/src/runtime/handle.rs | 23 +- tokio/src/runtime/local_runtime/runtime.rs | 4 +- tokio/src/runtime/mod.rs | 7 +- tokio/src/runtime/runtime.rs | 9 + .../runtime/scheduler/current_thread/mod.rs | 156 ++++-- tokio/src/runtime/scheduler/mod.rs | 40 +- .../runtime/scheduler/multi_thread/handle.rs | 82 +++- .../runtime/scheduler/multi_thread/worker.rs | 43 +- .../scheduler/multi_thread_alt/handle.rs | 72 ++- .../scheduler/multi_thread_alt/worker.rs | 24 +- tokio/src/runtime/task/core.rs | 28 +- tokio/src/runtime/task/harness.rs | 39 +- tokio/src/runtime/task/list.rs | 25 +- tokio/src/runtime/task/mod.rs | 36 +- tokio/src/runtime/task/raw.rs | 49 +- tokio/src/runtime/task_hooks.rs | 81 --- tokio/src/runtime/task_hooks/mod.rs | 99 ++++ tokio/src/runtime/tests/mod.rs | 22 +- tokio/src/runtime/tests/queue.rs | 17 +- tokio/src/runtime/tests/task.rs | 19 +- tokio/src/task/builder.rs | 29 +- tokio/src/task/join_set.rs | 10 +- tokio/src/task/local.rs | 41 +- tokio/src/task/mod.rs | 4 + tokio/src/task/spawn.rs | 36 +- tokio/tests/rt_poll_callbacks.rs | 128 ----- tokio/tests/task_builder.rs | 30 +- tokio/tests/task_hooks.rs | 462 ++++++++++++++++-- tokio/tests/tracing_task.rs | 8 +- 35 files changed, 1260 insertions(+), 719 deletions(-) delete mode 100644 tokio/src/runtime/task_hooks.rs create mode 100644 tokio/src/runtime/task_hooks/mod.rs delete mode 100644 tokio/tests/rt_poll_callbacks.rs diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 6b0f48bd105..cf287b0dac8 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -351,10 +351,7 @@ //! - [`task::Builder`] //! - Some methods on [`task::JoinSet`] //! - [`runtime::RuntimeMetrics`] -//! - [`runtime::Builder::on_task_spawn`] -//! - [`runtime::Builder::on_task_terminate`] //! - [`runtime::Builder::unhandled_panic`] -//! - [`runtime::TaskMeta`] //! //! This flag enables **unstable** features. The public API of these features //! may break in 1.x releases. 
To enable these features, the `--cfg diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index 23180dc5245..990a1fd4a7b 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -375,10 +375,15 @@ impl Spawner { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { + // let parent = with_c let id = task::Id::next(); let fut = blocking_task::>(BlockingTask::new(func), spawn_meta, id.as_u64()); + #[cfg(tokio_unstable)] + let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id, None); + + #[cfg(not(tokio_unstable))] let (task, handle) = task::unowned(fut, BlockingSchedule::new(rt), id); let spawned = self.spawn_task(Task::new(task, is_mandatory), rt); diff --git a/tokio/src/runtime/blocking/schedule.rs b/tokio/src/runtime/blocking/schedule.rs index 875bf1c314e..0d658061cca 100644 --- a/tokio/src/runtime/blocking/schedule.rs +++ b/tokio/src/runtime/blocking/schedule.rs @@ -1,7 +1,9 @@ #[cfg(feature = "test-util")] use crate::runtime::scheduler; -use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks}; +use crate::runtime::task::{self, Task}; use crate::runtime::Handle; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; /// `task::Schedule` implementation that does nothing (except some bookkeeping /// in test-util builds). This is unique to the blocking scheduler as tasks @@ -12,7 +14,8 @@ use crate::runtime::Handle; pub(crate) struct BlockingSchedule { #[cfg(feature = "test-util")] handle: Handle, - hooks: TaskHarnessScheduleHooks, + #[cfg(tokio_unstable)] + hooks_factory: OptionalTaskHooksFactory, } impl BlockingSchedule { @@ -33,9 +36,8 @@ impl BlockingSchedule { BlockingSchedule { #[cfg(feature = "test-util")] handle: handle.clone(), - hooks: TaskHarnessScheduleHooks { - task_terminate_callback: handle.inner.hooks().task_terminate_callback.clone(), - }, + #[cfg(tokio_unstable)] + hooks_factory: handle.inner.hooks_factory(), } } } @@ -62,9 +64,13 @@ impl task::Schedule for BlockingSchedule { unreachable!(); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.hooks_factory.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.hooks_factory.as_ref().map(AsRef::as_ref) } } diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index f4bf3a25921..5c62d0ed9ec 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -1,15 +1,19 @@ #![cfg_attr(loom, allow(unused_imports))] +use crate::runtime::blocking::BlockingPool; use crate::runtime::handle::Handle; -use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime, TaskCallback}; +use crate::runtime::scheduler::CurrentThread; +use crate::runtime::{blocking, driver, Callback, HistogramBuilder, Runtime}; #[cfg(tokio_unstable)] -use crate::runtime::{metrics::HistogramConfiguration, LocalOptions, LocalRuntime, TaskMeta}; +use crate::runtime::{ + metrics::HistogramConfiguration, LocalOptions, LocalRuntime, OptionalTaskHooksFactory, + TaskHookHarnessFactory, +}; use crate::util::rand::{RngSeed, RngSeedGenerator}; - -use crate::runtime::blocking::BlockingPool; -use crate::runtime::scheduler::CurrentThread; use std::fmt; use std::io; +#[cfg(tokio_unstable)] +use std::sync::Arc; use std::thread::ThreadId; use 
std::time::Duration; @@ -85,19 +89,8 @@ pub struct Builder { /// To run after each thread is unparked. pub(super) after_unpark: Option, - /// To run before each task is spawned. - pub(super) before_spawn: Option, - - /// To run before each poll #[cfg(tokio_unstable)] - pub(super) before_poll: Option, - - /// To run after each poll - #[cfg(tokio_unstable)] - pub(super) after_poll: Option, - - /// To run after each task is terminated. - pub(super) after_termination: Option, + pub(super) task_hook_harness_factory: OptionalTaskHooksFactory, /// Customizable keep alive timeout for `BlockingPool` pub(super) keep_alive: Option, @@ -311,13 +304,8 @@ impl Builder { before_park: None, after_unpark: None, - before_spawn: None, - after_termination: None, - #[cfg(tokio_unstable)] - before_poll: None, - #[cfg(tokio_unstable)] - after_poll: None, + task_hook_harness_factory: None, keep_alive: None, @@ -706,188 +694,19 @@ impl Builder { self } - /// Executes function `f` just before a task is spawned. - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// This can be used for bookkeeping or monitoring purposes. - /// - /// Note: There can only be one spawn callback for a runtime; calling this function more - /// than once replaces the last callback defined, rather than adding to it. - /// - /// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use tokio::runtime; - /// # pub fn main() { - /// let runtime = runtime::Builder::new_current_thread() - /// .on_task_spawn(|_| { - /// println!("spawning task"); - /// }) - /// .build() - /// .unwrap(); + /// Factory method for producing "fallback" task hook harnesses. /// - /// runtime.block_on(async { - /// tokio::task::spawn(std::future::ready(())); - /// - /// for _ in 0..64 { - /// tokio::task::yield_now().await; - /// } - /// }) - /// # } - /// ``` + /// The order of operations for assigning the hook harness for a task are as follows: + /// 1. [`crate::task::spawn_with_hooks`], if used. + /// 2. [`crate::runtime::task_hooks::TaskHookHarnessFactory`], if it returns something other than [Option::None]. + /// 3. This function. #[cfg(all(not(loom), tokio_unstable))] #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] - pub fn on_task_spawn(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.before_spawn = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just before a task is polled - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. 
- /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use std::sync::{atomic::AtomicUsize, Arc}; - /// # use tokio::task::yield_now; - /// # pub fn main() { - /// let poll_start_counter = Arc::new(AtomicUsize::new(0)); - /// let poll_start = poll_start_counter.clone(); - /// let rt = tokio::runtime::Builder::new_multi_thread() - /// .enable_all() - /// .on_before_task_poll(move |meta| { - /// println!("task {} is about to be polled", meta.id()) - /// }) - /// .build() - /// .unwrap(); - /// let task = rt.spawn(async { - /// yield_now().await; - /// }); - /// let _ = rt.block_on(task); - /// - /// # } - /// ``` - #[cfg(tokio_unstable)] - pub fn on_before_task_poll(&mut self, f: F) -> &mut Self + pub fn hook_harness_factory(&mut self, hooks: T) -> &mut Self where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, + T: TaskHookHarnessFactory + Send + Sync + 'static, { - self.before_poll = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just after a task is polled - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called, and may result in this callback being - /// invoked immediately. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. - /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use std::sync::{atomic::AtomicUsize, Arc}; - /// # use tokio::task::yield_now; - /// # pub fn main() { - /// let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - /// let poll_stop = poll_stop_counter.clone(); - /// let rt = tokio::runtime::Builder::new_multi_thread() - /// .enable_all() - /// .on_after_task_poll(move |meta| { - /// println!("task {} completed polling", meta.id()); - /// }) - /// .build() - /// .unwrap(); - /// let task = rt.spawn(async { - /// yield_now().await; - /// }); - /// let _ = rt.block_on(task); - /// - /// # } - /// ``` - #[cfg(tokio_unstable)] - pub fn on_after_task_poll(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.after_poll = Some(std::sync::Arc::new(f)); - self - } - - /// Executes function `f` just after a task is terminated. - /// - /// `f` is called within the Tokio context, so functions like - /// [`tokio::spawn`](crate::spawn) can be called. - /// - /// This can be used for bookkeeping or monitoring purposes. - /// - /// Note: There can only be one task termination callback for a runtime; calling this - /// function more than once replaces the last callback defined, rather than adding to it. - /// - /// This *does not* support [`LocalSet`](crate::task::LocalSet) at this time. - /// - /// **Note**: This is an [unstable API][unstable]. The public API of this type - /// may break in 1.x releases. See [the documentation on unstable - /// features][unstable] for details. 
- /// - /// [unstable]: crate#unstable-features - /// - /// # Examples - /// - /// ``` - /// # use tokio::runtime; - /// # pub fn main() { - /// let runtime = runtime::Builder::new_current_thread() - /// .on_task_terminate(|_| { - /// println!("killing task"); - /// }) - /// .build() - /// .unwrap(); - /// - /// runtime.block_on(async { - /// tokio::task::spawn(std::future::ready(())); - /// - /// for _ in 0..64 { - /// tokio::task::yield_now().await; - /// } - /// }) - /// # } - /// ``` - #[cfg(all(not(loom), tokio_unstable))] - #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] - pub fn on_task_terminate(&mut self, f: F) -> &mut Self - where - F: Fn(&TaskMeta<'_>) + Send + Sync + 'static, - { - self.after_termination = Some(std::sync::Arc::new(f)); + self.task_hook_harness_factory = Some(Arc::new(hooks)); self } @@ -1508,12 +1327,8 @@ impl Builder { Config { before_park: self.before_park.clone(), after_unpark: self.after_unpark.clone(), - before_spawn: self.before_spawn.clone(), #[cfg(tokio_unstable)] - before_poll: self.before_poll.clone(), - #[cfg(tokio_unstable)] - after_poll: self.after_poll.clone(), - after_termination: self.after_termination.clone(), + task_hook_factory: self.task_hook_harness_factory.clone(), global_queue_interval: self.global_queue_interval, event_interval: self.event_interval, local_queue_capacity: self.local_queue_capacity, @@ -1662,12 +1477,8 @@ cfg_rt_multi_thread! { Config { before_park: self.before_park.clone(), after_unpark: self.after_unpark.clone(), - before_spawn: self.before_spawn.clone(), - #[cfg(tokio_unstable)] - before_poll: self.before_poll.clone(), #[cfg(tokio_unstable)] - after_poll: self.after_poll.clone(), - after_termination: self.after_termination.clone(), + task_hook_factory: self.task_hook_harness_factory.clone(), global_queue_interval: self.global_queue_interval, event_interval: self.event_interval, local_queue_capacity: self.local_queue_capacity, @@ -1715,12 +1526,8 @@ cfg_rt_multi_thread! { Config { before_park: self.before_park.clone(), after_unpark: self.after_unpark.clone(), - before_spawn: self.before_spawn.clone(), - after_termination: self.after_termination.clone(), - #[cfg(tokio_unstable)] - before_poll: self.before_poll.clone(), #[cfg(tokio_unstable)] - after_poll: self.after_poll.clone(), + task_hook_factory: self.task_hook_harness_factory.clone(), global_queue_interval: self.global_queue_interval, event_interval: self.event_interval, local_queue_capacity: self.local_queue_capacity, diff --git a/tokio/src/runtime/config.rs b/tokio/src/runtime/config.rs index 43ce5aebd63..dd79289c85e 100644 --- a/tokio/src/runtime/config.rs +++ b/tokio/src/runtime/config.rs @@ -2,7 +2,10 @@ any(not(all(tokio_unstable, feature = "full")), target_family = "wasm"), allow(dead_code) )] -use crate::runtime::{Callback, TaskCallback}; + +use crate::runtime::Callback; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooksFactory; use crate::util::RngSeedGenerator; pub(crate) struct Config { @@ -21,19 +24,9 @@ pub(crate) struct Config { /// Callback for a worker unparking itself pub(crate) after_unpark: Option, - /// To run before each task is spawned. - pub(crate) before_spawn: Option, - - /// To run after each task is terminated. - pub(crate) after_termination: Option, - - /// To run before each poll - #[cfg(tokio_unstable)] - pub(crate) before_poll: Option, - - /// To run after each poll + /// Called on task spawn to generate the attached task hook harness. 
#[cfg(tokio_unstable)] - pub(crate) after_poll: Option, + pub(crate) task_hook_factory: OptionalTaskHooksFactory, /// The multi-threaded scheduler includes a per-worker LIFO slot used to /// store the last scheduled task. This can improve certain usage patterns, diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index e8f17bb374a..c0fcc64aa15 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -1,10 +1,14 @@ +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::loom::cell::UnsafeCell; use crate::loom::thread::AccessError; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::{OptionalTaskHooksMut, OptionalTaskHooksWeak, TaskHookHarness}; use crate::task::coop; - -use std::cell::Cell; - #[cfg(any(feature = "rt", feature = "macros", feature = "time"))] use crate::util::rand::FastRand; +use std::cell::Cell; +#[cfg(all(feature = "rt", tokio_unstable))] +use std::ptr::NonNull; cfg_rt! { mod blocking; @@ -49,6 +53,10 @@ struct Context { #[cfg(feature = "rt")] current_task_id: Cell>, + /// Tracks the current set of task hooks, + #[cfg(all(feature = "rt", tokio_unstable))] + current_task_hooks: OptionalTaskHooksWeak, + /// Tracks if the current thread is currently driving a runtime. /// Note, that if this is set to "entered", the current scheduler /// handle may not reference the runtime currently executing. This @@ -92,6 +100,9 @@ tokio_thread_local! { #[cfg(feature = "rt")] current_task_id: Cell::new(None), + #[cfg(all(feature = "rt", tokio_unstable))] + current_task_hooks: UnsafeCell::new(None), + // Tracks if the current thread is currently driving a runtime. // Note, that if this is set to "entered", the current scheduler // handle may not reference the runtime currently executing. This @@ -139,6 +150,16 @@ pub(crate) fn budget(f: impl FnOnce(&Cell) -> R) -> Result>) -> Result { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|x| { + unsafe { + *x = hooks; + } + }) + })?; + + Ok(SetTaskHooksGuard) + } + + #[track_caller] + #[cfg(tokio_unstable)] + pub(super) fn clear_task_hooks() -> Result<(), AccessError> { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|x| { + unsafe { + *x = None; + } + }) + })?; + + Ok(()) + } + + #[track_caller] + #[cfg(tokio_unstable)] + pub(super) fn with_task_hooks(f: impl FnOnce(OptionalTaskHooksMut<'_>) -> R) -> Result { + CONTEXT.try_with(|ctx| { + ctx.current_task_hooks.with_mut(|ptr| { + let hooks = unsafe { &mut *ptr }; + unsafe { + f(hooks.as_mut().map(|x| x.as_mut())) + } + }) + }) + } + #[track_caller] pub(crate) fn defer(waker: &Waker) { with_scheduler(|maybe_scheduler| { diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 91f13d6c2ed..371993c9722 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -1,5 +1,5 @@ #[cfg(tokio_unstable)] -use crate::runtime; +use crate::runtime::{self, OptionalTaskHooks}; use crate::runtime::{context, scheduler, RuntimeFlavor, RuntimeMetrics}; /// Handle to the runtime. 
@@ -191,6 +191,13 @@ impl Handle { F::Output: Send + 'static, { let fut_size = mem::size_of::(); + #[cfg(tokio_unstable)] + return if fut_size > BOX_FUTURE_THRESHOLD { + self.spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + self.spawn_named(future, SpawnMeta::new_unnamed(fut_size), None) + }; + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { self.spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -329,7 +336,12 @@ impl Handle { } #[track_caller] - pub(crate) fn spawn_named(&self, future: F, _meta: SpawnMeta<'_>) -> JoinHandle + pub(crate) fn spawn_named( + &self, + future: F, + _meta: SpawnMeta<'_>, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, + ) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, @@ -345,6 +357,9 @@ impl Handle { let future = super::task::trace::Trace::root(future); #[cfg(all(tokio_unstable, feature = "tracing"))] let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + #[cfg(tokio_unstable)] + return self.inner.spawn(future, id, parent); + #[cfg(not(tokio_unstable))] self.inner.spawn(future, id) } @@ -354,6 +369,7 @@ impl Handle { &self, future: F, _meta: SpawnMeta<'_>, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: Future + 'static, @@ -370,6 +386,9 @@ impl Handle { let future = super::task::trace::Trace::root(future); #[cfg(all(tokio_unstable, feature = "tracing"))] let future = crate::util::trace::task(future, "task", _meta, id.as_u64()); + #[cfg(tokio_unstable)] + return self.inner.spawn_local(future, id, hooks_override); + #[cfg(not(tokio_unstable))] self.inner.spawn_local(future, id) } diff --git a/tokio/src/runtime/local_runtime/runtime.rs b/tokio/src/runtime/local_runtime/runtime.rs index 358a771956b..11fdc097b17 100644 --- a/tokio/src/runtime/local_runtime/runtime.rs +++ b/tokio/src/runtime/local_runtime/runtime.rs @@ -155,9 +155,9 @@ impl LocalRuntime { // safety: spawn_local can only be called from `LocalRuntime`, which this is unsafe { if std::mem::size_of::() > BOX_FUTURE_THRESHOLD { - self.handle.spawn_local_named(Box::pin(future), meta) + self.handle.spawn_local_named(Box::pin(future), meta, None) } else { - self.handle.spawn_local_named(future, meta) + self.handle.spawn_local_named(future, meta, None) } } } diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 78a0114f48e..026bdd7ef68 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -380,13 +380,10 @@ cfg_rt! { pub use dump::Dump; } - mod task_hooks; - pub(crate) use task_hooks::{TaskHooks, TaskCallback}; cfg_unstable! 
{ - pub use task_hooks::TaskMeta; + mod task_hooks; + pub use task_hooks::*; } - #[cfg(not(tokio_unstable))] - pub(crate) use task_hooks::TaskMeta; mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; diff --git a/tokio/src/runtime/runtime.rs b/tokio/src/runtime/runtime.rs index ea6bb247941..355dd22b6e4 100644 --- a/tokio/src/runtime/runtime.rs +++ b/tokio/src/runtime/runtime.rs @@ -245,6 +245,15 @@ impl Runtime { F::Output: Send + 'static, { let fut_size = mem::size_of::(); + #[cfg(tokio_unstable)] + return if fut_size > BOX_FUTURE_THRESHOLD { + self.handle + .spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + self.handle + .spawn_named(future, SpawnMeta::new_unnamed(fut_size), None) + }; + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { self.handle .spawn_named(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) diff --git a/tokio/src/runtime/scheduler/current_thread/mod.rs b/tokio/src/runtime/scheduler/current_thread/mod.rs index 13c803e0d71..905822c75c3 100644 --- a/tokio/src/runtime/scheduler/current_thread/mod.rs +++ b/tokio/src/runtime/scheduler/current_thread/mod.rs @@ -1,17 +1,19 @@ use crate::loom::sync::atomic::AtomicBool; use crate::loom::sync::Arc; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::driver::{self, Driver}; use crate::runtime::scheduler::{self, Defer, Inject}; -use crate::runtime::task::{ - self, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks, -}; +use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::{blocking, context, Config, MetricsBatch, SchedulerMetrics, WorkerMetrics}; +#[cfg(tokio_unstable)] use crate::runtime::{ - blocking, context, Config, MetricsBatch, SchedulerMetrics, TaskHooks, TaskMeta, WorkerMetrics, + OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, OptionalTaskHooks, + OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef, }; use crate::sync::notify::Notify; use crate::util::atomic_cell::AtomicCell; use crate::util::{waker_ref, RngSeedGenerator, Wake, WakerRef}; - use std::cell::RefCell; use std::collections::VecDeque; use std::future::{poll_fn, Future}; @@ -20,7 +22,7 @@ use std::task::Poll::{Pending, Ready}; use std::task::Waker; use std::thread::ThreadId; use std::time::Duration; -use std::{fmt, thread}; +use std::{fmt, panic, thread}; /// Executes tasks on the current thread pub(crate) struct CurrentThread { @@ -47,7 +49,8 @@ pub(crate) struct Handle { pub(crate) seed_generator: RngSeedGenerator, /// User-supplied hooks to invoke for things - pub(crate) task_hooks: TaskHooks, + #[cfg(tokio_unstable)] + pub(crate) task_hooks: OptionalTaskHooksFactory, /// If this is a `LocalRuntime`, flags the owning thread ID. 
pub(crate) local_tid: Option, @@ -142,14 +145,8 @@ impl CurrentThread { .unwrap_or(DEFAULT_GLOBAL_QUEUE_INTERVAL); let handle = Arc::new(Handle { - task_hooks: TaskHooks { - task_spawn_callback: config.before_spawn.clone(), - task_terminate_callback: config.after_termination.clone(), - #[cfg(tokio_unstable)] - before_poll_callback: config.before_poll.clone(), - #[cfg(tokio_unstable)] - after_poll_callback: config.after_poll.clone(), - }, + #[cfg(tokio_unstable)] + task_hooks: config.task_hook_factory.clone(), shared: Shared { inject: Inject::new(), owned: OwnedTasks::new(1), @@ -448,19 +445,61 @@ impl Handle { pub(crate) fn spawn( me: &Arc, future: F, - id: crate::runtime::task::Id, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); - - me.task_hooks.spawn(&TaskMeta { - id, - _phantom: Default::default(), + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent.on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + } else { + None + } + }) }); + #[cfg(tokio_unstable)] + let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, hooks); + + #[cfg(not(tokio_unstable))] + let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); + if let Some(notified) = notified { me.schedule(notified); } @@ -477,18 +516,62 @@ impl Handle { pub(crate) unsafe fn spawn_local( me: &Arc, future: F, - id: crate::runtime::task::Id, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, ) -> JoinHandle where F: crate::future::Future + 'static, F::Output: 'static, { - let (handle, notified) = me.shared.owned.bind_local(future, me.clone(), id); + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. 
runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent.on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + } else { + None + } + }) + }); - me.task_hooks.spawn(&TaskMeta { + let (handle, notified) = me.shared.owned.bind_local( + future, + me.clone(), id, - _phantom: Default::default(), - }); + #[cfg(tokio_unstable)] + hooks, + ); if let Some(notified) = notified { me.schedule(notified); @@ -654,10 +737,14 @@ impl Schedule for Arc { }); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.task_hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.task_hooks.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.task_hooks.as_ref().map(AsRef::as_ref) } cfg_unstable! { @@ -770,17 +857,8 @@ impl CoreGuard<'_> { let task = context.handle.shared.owned.assert_owner(task); - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - let (c, ()) = context.run_task(core, || { - #[cfg(tokio_unstable)] - context.handle.task_hooks.poll_start_callback(task_id); - task.run(); - - #[cfg(tokio_unstable)] - context.handle.task_hooks.poll_stop_callback(task_id); }); core = c; diff --git a/tokio/src/runtime/scheduler/mod.rs b/tokio/src/runtime/scheduler/mod.rs index e0a1b20b5bc..007b5389a18 100644 --- a/tokio/src/runtime/scheduler/mod.rs +++ b/tokio/src/runtime/scheduler/mod.rs @@ -7,8 +7,6 @@ cfg_rt! { pub(crate) mod inject; pub(crate) use inject::Inject; - - use crate::runtime::TaskHooks; } cfg_rt_multi_thread! { @@ -28,6 +26,10 @@ cfg_rt_multi_thread! { } use crate::runtime::driver; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::task::Schedule; +#[cfg(all(feature = "rt", tokio_unstable))] +use crate::runtime::{OptionalTaskHooks, OptionalTaskHooksFactory}; #[derive(Debug, Clone)] pub(crate) enum Handle { @@ -138,11 +140,27 @@ cfg_rt! { } } - pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle + pub(crate) fn spawn(&self, + future: F, + id: Id, + #[cfg(tokio_unstable)] + hooks_override: OptionalTaskHooks + ) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, { + #[cfg(tokio_unstable)] + return match self { + Handle::CurrentThread(h) => current_thread::Handle::spawn(h, future, id, hooks_override), + + #[cfg(feature = "rt-multi-thread")] + Handle::MultiThread(h) => multi_thread::Handle::spawn(h, future, id, hooks_override), + + #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] + Handle::MultiThreadAlt(h) => multi_thread_alt::Handle::spawn(h, future, id, hooks_override), + }; + #[cfg(not(tokio_unstable))] match self { Handle::CurrentThread(h) => current_thread::Handle::spawn(h, future, id), @@ -160,12 +178,15 @@ cfg_rt! { /// This should only be called in `LocalRuntime` if the runtime has been verified to be owned /// by the current thread. 
#[allow(irrefutable_let_patterns)] - pub(crate) unsafe fn spawn_local(&self, future: F, id: Id) -> JoinHandle + pub(crate) unsafe fn spawn_local(&self, future: F, id: Id, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where F: Future + 'static, F::Output: 'static, { if let Handle::CurrentThread(h) = self { + #[cfg(tokio_unstable)] + return current_thread::Handle::spawn_local(h, future, id, hooks_override); + #[cfg(not(tokio_unstable))] current_thread::Handle::spawn_local(h, future, id) } else { panic!("Only current_thread and LocalSet have spawn_local internals implemented") @@ -196,14 +217,9 @@ cfg_rt! { } } - pub(crate) fn hooks(&self) -> &TaskHooks { - match self { - Handle::CurrentThread(h) => &h.task_hooks, - #[cfg(feature = "rt-multi-thread")] - Handle::MultiThread(h) => &h.task_hooks, - #[cfg(all(tokio_unstable, feature = "rt-multi-thread"))] - Handle::MultiThreadAlt(h) => &h.task_hooks, - } + #[cfg(tokio_unstable)] + pub(crate) fn hooks_factory(&self) -> OptionalTaskHooksFactory { + match_flavor!(self, Handle(h) => h.hooks_factory()) } cfg_rt_multi_thread! { diff --git a/tokio/src/runtime/scheduler/multi_thread/handle.rs b/tokio/src/runtime/scheduler/multi_thread/handle.rs index 4075713c979..fa8973a8fab 100644 --- a/tokio/src/runtime/scheduler/multi_thread/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread/handle.rs @@ -1,14 +1,20 @@ use crate::future::Future; use crate::loom::sync::Arc; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::scheduler::multi_thread::worker; +#[cfg(tokio_unstable)] +use crate::runtime::task::Schedule; use crate::runtime::{ blocking, driver, task::{self, JoinHandle}, - TaskHooks, TaskMeta, }; +#[cfg(tokio_unstable)] +use crate::runtime::{OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, OptionalTaskHooks}; use crate::util::RngSeedGenerator; - use std::fmt; +#[cfg(tokio_unstable)] +use std::panic; mod metrics; @@ -29,18 +35,24 @@ pub(crate) struct Handle { /// Current random number generator seed pub(crate) seed_generator: RngSeedGenerator, - - /// User-supplied hooks to invoke for things - pub(crate) task_hooks: TaskHooks, } impl Handle { /// Spawns a future onto the thread pool - pub(crate) fn spawn(me: &Arc, future: F, id: task::Id) -> JoinHandle + pub(crate) fn spawn( + me: &Arc, + future: F, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, + ) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { + #[cfg(tokio_unstable)] + return Self::bind_new_task(me, future, id, hooks_override); + + #[cfg(not(tokio_unstable))] Self::bind_new_task(me, future, id) } @@ -48,17 +60,65 @@ impl Handle { self.close(); } - pub(super) fn bind_new_task(me: &Arc, future: T, id: task::Id) -> JoinHandle + pub(super) fn bind_new_task( + me: &Arc, + future: T, + id: task::Id, + #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks, + ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. 
runtime hook factory + #[cfg(tokio_unstable)] + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent.on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + } else { + None + } + }) + }); - me.task_hooks.spawn(&TaskMeta { + let (handle, notified) = me.shared.owned.bind( + future, + me.clone(), id, - _phantom: Default::default(), - }); + #[cfg(tokio_unstable)] + hooks, + ); me.schedule_option_task_without_yield(notified); diff --git a/tokio/src/runtime/scheduler/multi_thread/worker.rs b/tokio/src/runtime/scheduler/multi_thread/worker.rs index 8db4fb4ec96..e573c78b0e5 100644 --- a/tokio/src/runtime/scheduler/multi_thread/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread/worker.rs @@ -58,13 +58,15 @@ use crate::loom::sync::{Arc, Mutex}; use crate::runtime; +use crate::runtime::context; use crate::runtime::scheduler::multi_thread::{ idle, queue, Counters, Handle, Idle, Overflow, Parker, Stats, TraceStatus, Unparker, }; use crate::runtime::scheduler::{inject, Defer, Lock}; -use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks}; +use crate::runtime::task::OwnedTasks; use crate::runtime::{blocking, driver, scheduler, task, Config, SchedulerMetrics, WorkerMetrics}; -use crate::runtime::{context, TaskHooks}; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use crate::task::coop; use crate::util::atomic_cell::AtomicCell; use crate::util::rand::{FastRand, RngSeedGenerator}; @@ -281,7 +283,6 @@ pub(super) fn create( let remotes_len = remotes.len(); let handle = Arc::new(Handle { - task_hooks: TaskHooks::from_config(&config), shared: Shared { remotes: remotes.into_boxed_slice(), inject, @@ -570,9 +571,6 @@ impl Context { } fn run_task(&self, task: Notified, mut core: Box) -> RunResult { - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - let task = self.worker.handle.shared.owned.assert_owner(task); // Make sure the worker is not in the **searching** state. This enables @@ -592,16 +590,8 @@ impl Context { // Run the task coop::budget(|| { - // Unlike the poll time above, poll start callback is attached to the task id, - // so it is tightly associated with the actual poll invocation. 
- #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_start_callback(task_id); - task.run(); - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_stop_callback(task_id); - let mut lifo_polls = 0; // As long as there is budget remaining and a task exists in the @@ -665,16 +655,7 @@ impl Context { *self.core.borrow_mut() = Some(core); let task = self.worker.handle.shared.owned.assert_owner(task); - #[cfg(tokio_unstable)] - let task_id = task.task_id(); - - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_start_callback(task_id); - task.run(); - - #[cfg(tokio_unstable)] - self.worker.handle.task_hooks.poll_stop_callback(task_id); } }) } @@ -1057,10 +1038,18 @@ impl task::Schedule for Arc { self.schedule_task(task, false); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.task_hooks.task_terminate_callback.clone(), - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.shared.config.task_hook_factory.clone() + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.shared + .config + .task_hook_factory + .as_ref() + .map(AsRef::as_ref) } fn yield_now(&self, task: Notified) { diff --git a/tokio/src/runtime/scheduler/multi_thread_alt/handle.rs b/tokio/src/runtime/scheduler/multi_thread_alt/handle.rs index 3b730974925..372110c07cb 100644 --- a/tokio/src/runtime/scheduler/multi_thread_alt/handle.rs +++ b/tokio/src/runtime/scheduler/multi_thread_alt/handle.rs @@ -4,11 +4,13 @@ use crate::runtime::scheduler::multi_thread_alt::worker; use crate::runtime::{ blocking, driver, task::{self, JoinHandle}, - TaskHooks, TaskMeta, + OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, OptionalTaskHooks, }; use crate::util::RngSeedGenerator; -use std::fmt; +use crate::runtime::context::with_task_hooks; +use crate::runtime::task::Schedule; +use std::{fmt, panic}; cfg_unstable_metrics! { mod metrics; @@ -27,19 +29,21 @@ pub(crate) struct Handle { /// Current random number generator seed pub(crate) seed_generator: RngSeedGenerator, - - /// User-supplied hooks to invoke for things - pub(crate) task_hooks: TaskHooks, } impl Handle { /// Spawns a future onto the thread pool - pub(crate) fn spawn(me: &Arc, future: F, id: task::Id) -> JoinHandle + pub(crate) fn spawn( + me: &Arc, + future: F, + id: task::Id, + hooks_override: OptionalTaskHooks, + ) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - Self::bind_new_task(me, future, id) + Self::bind_new_task(me, future, id, hooks_override) } pub(crate) fn shutdown(&self) { @@ -47,19 +51,59 @@ impl Handle { self.driver.unpark(); } - pub(super) fn bind_new_task(me: &Arc, future: T, id: task::Id) -> JoinHandle + pub(super) fn bind_new_task( + me: &Arc, + future: T, + id: task::Id, + hooks_override: OptionalTaskHooks, + ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - let (handle, notified) = me.shared.owned.bind(future, me.clone(), id); - - me.task_hooks.spawn(&TaskMeta { - #[cfg(tokio_unstable)] - id, - _phantom: Default::default(), + // preference order for hook selection: + // 1. "hook override" - comes from builder + // 2. parent task's hook + // 3. 
runtime hook factory + let hooks = hooks_override.or_else(|| { + with_task_hooks(|parent| { + parent + .map(|parent| { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + parent.on_child_spawn(&mut OnChildTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + }) + .flatten() + }) + .ok() + .flatten() + .or_else(|| { + if let Some(hooks) = me.hooks_factory_ref() { + if let Ok(r) = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_top_level_spawn(&mut OnTopLevelTaskSpawnContext { + id, + _phantom: Default::default(), + }) + })) { + r + } else { + None + } + } else { + None + } + }) }); + let (handle, notified) = me.shared.owned.bind(future, me.clone(), id, hooks); + if let Some(notified) = notified { me.shared.schedule_task(notified, false); } diff --git a/tokio/src/runtime/scheduler/multi_thread_alt/worker.rs b/tokio/src/runtime/scheduler/multi_thread_alt/worker.rs index f3f42e6158a..2eef399ccb0 100644 --- a/tokio/src/runtime/scheduler/multi_thread_alt/worker.rs +++ b/tokio/src/runtime/scheduler/multi_thread_alt/worker.rs @@ -58,14 +58,17 @@ use crate::loom::sync::{Arc, Condvar, Mutex, MutexGuard}; use crate::runtime; +use crate::runtime::context; use crate::runtime::driver::Driver; use crate::runtime::scheduler::multi_thread_alt::{ idle, queue, stats, Counters, Handle, Idle, Overflow, Stats, TraceStatus, }; use crate::runtime::scheduler::{self, inject, Lock}; -use crate::runtime::task::{OwnedTasks, TaskHarnessScheduleHooks}; -use crate::runtime::{blocking, driver, task, Config, SchedulerMetrics, WorkerMetrics}; -use crate::runtime::{context, TaskHooks}; +use crate::runtime::task::OwnedTasks; +use crate::runtime::{ + blocking, driver, task, Config, OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef, + SchedulerMetrics, WorkerMetrics, +}; use crate::task::coop; use crate::util::atomic_cell::AtomicCell; use crate::util::rand::{FastRand, RngSeedGenerator}; @@ -304,7 +307,6 @@ pub(super) fn create( let (inject, inject_synced) = inject::Shared::new(); let handle = Arc::new(Handle { - task_hooks: TaskHooks::from_config(&config), shared: Shared { remotes: remotes.into_boxed_slice(), inject, @@ -1558,10 +1560,16 @@ impl task::Schedule for Arc { self.shared.schedule_task(task, false); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: self.task_hooks.task_terminate_callback.clone(), - } + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + self.shared.config.task_hook_factory.clone() + } + + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + self.shared + .config + .task_hook_factory + .as_ref() + .map(AsRef::as_ref) } fn yield_now(&self, task: Notified) { diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 5d3ca0e00c9..7f122581c29 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -14,7 +14,9 @@ use crate::loom::cell::UnsafeCell; use crate::runtime::context; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::{Id, Schedule, TaskHarnessScheduleHooks}; +use crate::runtime::task::{Id, Schedule}; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooks; use crate::util::linked_list; use std::num::NonZeroU64; @@ -186,7 +188,8 @@ pub(super) struct Trailer { /// Consumer task waiting on completion of this task. pub(super) waker: UnsafeCell>, /// Optional hooks needed in the harness. 
- pub(super) hooks: TaskHarnessScheduleHooks, + #[cfg(tokio_unstable)] + pub(super) hooks: UnsafeCell, } generate_addr_of_methods! { @@ -208,7 +211,13 @@ pub(super) enum Stage { impl Cell { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box> { + pub(super) fn new( + future: T, + scheduler: S, + state: State, + task_id: Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, + ) -> Box> { // Separated into a non-generic function to reduce LLVM codegen fn new_header( state: State, @@ -229,7 +238,13 @@ impl Cell { let tracing_id = future.id(); let vtable = raw::vtable::(); let result = Box::new(Cell { - trailer: Trailer::new(scheduler.hooks()), + #[cfg(tokio_unstable)] + trailer: Trailer::new( + #[cfg(tokio_unstable)] + hooks, + ), + #[cfg(not(tokio_unstable))] + trailer: Trailer::new(), header: new_header( state, vtable, @@ -462,11 +477,12 @@ impl Header { } impl Trailer { - fn new(hooks: TaskHarnessScheduleHooks) -> Self { + fn new(#[cfg(tokio_unstable)] hooks: OptionalTaskHooks) -> Self { Trailer { waker: UnsafeCell::new(None), owned: linked_list::Pointers::new(), - hooks, + #[cfg(tokio_unstable)] + hooks: UnsafeCell::new(hooks), } } diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 9bf73b74fbf..f039731a51c 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -1,10 +1,12 @@ use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::context::with_task_hooks; use crate::runtime::task::core::{Cell, Core, Header, Trailer}; use crate::runtime::task::state::{Snapshot, State}; use crate::runtime::task::waker::waker_ref; use crate::runtime::task::{Id, JoinError, Notified, RawTask, Schedule, Task}; - -use crate::runtime::TaskMeta; +#[cfg(tokio_unstable)] +use crate::runtime::{AfterTaskPollContext, OnTaskTerminateContext}; use std::any::Any; use std::mem; use std::mem::ManuallyDrop; @@ -150,8 +152,21 @@ where /// All necessary state checks and transitions are performed. /// Panics raised while polling the future are handled. pub(super) fn poll(self) { + let res = self.poll_inner(); + + #[cfg(tokio_unstable)] + let _ = with_task_hooks(|t| { + if let Some(hooks) = t { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.after_poll(&mut AfterTaskPollContext { + _phantom: Default::default(), + }) + })); + } + }); + // We pass our ref-count to `poll_inner`. - match self.poll_inner() { + match res { PollFuture::Notified => { // The `poll_inner` call has given us two ref-counts back. // We give one of them to a new task and call `yield_now`. @@ -367,14 +382,16 @@ where // // We call this in a separate block so that it runs after the task appears to have // completed and will still run if the destructor panics. - if let Some(f) = self.trailer().hooks.task_terminate_callback.as_ref() { - let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { - f(&TaskMeta { - id: self.core().task_id, - _phantom: Default::default(), - }) - })); - } + #[cfg(tokio_unstable)] + let _ = with_task_hooks(|t| { + if let Some(hooks) = t { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + hooks.on_task_terminate(&mut OnTaskTerminateContext { + _phantom: Default::default(), + }) + })); + } + }); // The task has completed execution and will no longer be scheduled. 
let num_release = self.release(); diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 54bfc01aafb..91e89dd4ffa 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -13,9 +13,10 @@ use crate::util::linked_list::{Link, LinkedList}; use crate::util::sharded_list; use crate::loom::sync::atomic::{AtomicBool, Ordering}; +#[cfg(tokio_unstable)] +use crate::runtime::OptionalTaskHooks; use std::marker::PhantomData; use std::num::NonZeroU64; - // The id from the module below is used to verify whether a given task is stored // in this OwnedTasks, or some other task. The counter starts at one so we can // use `None` for tasks not owned by any list. @@ -91,13 +92,20 @@ impl OwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + #[cfg(tokio_unstable)] + hooks, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -111,13 +119,20 @@ impl OwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler, id); + let (task, notified, join) = super::new_task( + task, + scheduler, + id, + #[cfg(tokio_unstable)] + parent, + ); let notified = unsafe { self.bind_inner(task, notified) }; (join, notified) } @@ -258,12 +273,16 @@ impl LocalOwnedTasks { task: T, scheduler: S, id: super::Id, + #[cfg(tokio_unstable)] parent: OptionalTaskHooks, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { + #[cfg(tokio_unstable)] + let (task, notified, join) = super::new_task(task, scheduler, id, parent); + #[cfg(not(tokio_unstable))] let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 7d314c3b176..7cf2f1e98f7 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -221,10 +221,10 @@ cfg_taskdump! { } use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooks, OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use crate::util::linked_list; use crate::util::sharded_list; - -use crate::runtime::TaskCallback; use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; @@ -256,13 +256,6 @@ pub(crate) struct LocalNotified { _not_send: PhantomData<*const ()>, } -impl LocalNotified { - #[cfg(tokio_unstable)] - pub(crate) fn task_id(&self) -> Id { - self.task.id() - } -} - /// A task that is not owned by any `OwnedTasks`. Used for blocking tasks. /// This type holds two ref-counts. pub(crate) struct UnownedTask { @@ -277,12 +270,6 @@ unsafe impl Sync for UnownedTask {} /// Task result sent back. pub(crate) type Result = std::result::Result; -/// Hooks for scheduling tasks which are needed in the task harness. -#[derive(Clone)] -pub(crate) struct TaskHarnessScheduleHooks { - pub(crate) task_terminate_callback: Option, -} - pub(crate) trait Schedule: Sync + Sized + 'static { /// The task has completed work and is ready to be released. The scheduler /// should release it immediately and return it. 
The task module will batch @@ -294,7 +281,11 @@ pub(crate) trait Schedule: Sync + Sized + 'static { /// Schedule the task fn schedule(&self, task: Notified); - fn hooks(&self) -> TaskHarnessScheduleHooks; + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory; + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_>; /// Schedule the task to run in the near future, yielding the thread to /// other tasks. @@ -317,13 +308,19 @@ cfg_rt! { task: T, scheduler: S, id: Id, + #[cfg(tokio_unstable)] + hooks: OptionalTaskHooks ) -> (Task, Notified, JoinHandle) where S: Schedule, T: Future + 'static, T::Output: 'static, { + #[cfg(tokio_unstable)] + let raw = RawTask::new::(task, scheduler, id, hooks); + #[cfg(not(tokio_unstable))] let raw = RawTask::new::(task, scheduler, id); + let task = Task { raw, _p: PhantomData, @@ -341,12 +338,16 @@ cfg_rt! { /// only when the task is not going to be stored in an `OwnedTasks` list. /// /// Currently only blocking tasks use this method. - pub(crate) fn unowned(task: T, scheduler: S, id: Id) -> (UnownedTask, JoinHandle) + pub(crate) fn unowned(task: T, scheduler: S, id: Id, #[cfg(tokio_unstable)] hooks: OptionalTaskHooks) -> (UnownedTask, JoinHandle) where S: Schedule, T: Send + Future + 'static, T::Output: Send + 'static, { + #[cfg(tokio_unstable)] + let (task, notified, join) = new_task(task, scheduler, id, hooks); + + #[cfg(not(tokio_unstable))] let (task, notified, join) = new_task(task, scheduler, id); // This transfers the ref-count of task and notified into an UnownedTask. @@ -459,6 +460,7 @@ impl LocalNotified { /// Runs the task. pub(crate) fn run(self) { let raw = self.task.raw; + mem::forget(self); raw.poll(); } diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index 6699551f3ec..ad2f30677dc 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -1,7 +1,12 @@ use crate::future::Future; +#[cfg(tokio_unstable)] +use crate::runtime::context::set_task_hooks; use crate::runtime::task::core::{Core, Trailer}; use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State}; - +#[cfg(tokio_unstable)] +use crate::runtime::{BeforeTaskPollContext, OptionalTaskHooks, TaskHookHarness}; +#[cfg(tokio_unstable)] +use std::panic; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -157,12 +162,24 @@ const fn get_id_offset( } impl RawTask { - pub(super) fn new(task: T, scheduler: S, id: Id) -> RawTask + pub(super) fn new( + task: T, + scheduler: S, + id: Id, + #[cfg(tokio_unstable)] hooks: OptionalTaskHooks, + ) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id)); + let ptr = Box::into_raw(Cell::<_, S>::new( + task, + scheduler, + State::new(), + id, + #[cfg(tokio_unstable)] + hooks, + )); let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) }; RawTask { ptr } @@ -197,8 +214,30 @@ impl RawTask { /// Safety: mutual exclusion is required to call this function. 
pub(crate) fn poll(self) { - let vtable = self.header().vtable; - unsafe { (vtable.poll)(self.ptr) } + #[cfg(tokio_unstable)] + self.trailer().hooks.with_mut(|ptr| unsafe { + let _guard = ptr.as_mut().and_then(|x| { + x.as_mut().map(|x| { + let _ = panic::catch_unwind(panic::AssertUnwindSafe(|| { + x.before_poll(&mut BeforeTaskPollContext { + _phantom: Default::default(), + }) + })); + + set_task_hooks(NonNull::new( + (&mut **x) as *mut (dyn TaskHookHarness + Send + Sync + 'static), + )) + }) + }); + + let vtable = self.header().vtable; + (vtable.poll)(self.ptr); + }); + #[cfg(not(tokio_unstable))] + unsafe { + let vtable = self.header().vtable; + (vtable.poll)(self.ptr); + } } pub(super) fn schedule(self) { diff --git a/tokio/src/runtime/task_hooks.rs b/tokio/src/runtime/task_hooks.rs deleted file mode 100644 index 13865ed515d..00000000000 --- a/tokio/src/runtime/task_hooks.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::marker::PhantomData; - -use super::Config; - -impl TaskHooks { - pub(crate) fn spawn(&self, meta: &TaskMeta<'_>) { - if let Some(f) = self.task_spawn_callback.as_ref() { - f(meta) - } - } - - #[allow(dead_code)] - pub(crate) fn from_config(config: &Config) -> Self { - Self { - task_spawn_callback: config.before_spawn.clone(), - task_terminate_callback: config.after_termination.clone(), - #[cfg(tokio_unstable)] - before_poll_callback: config.before_poll.clone(), - #[cfg(tokio_unstable)] - after_poll_callback: config.after_poll.clone(), - } - } - - #[cfg(tokio_unstable)] - #[inline] - pub(crate) fn poll_start_callback(&self, id: super::task::Id) { - if let Some(poll_start) = &self.before_poll_callback { - (poll_start)(&TaskMeta { - id, - _phantom: std::marker::PhantomData, - }) - } - } - - #[cfg(tokio_unstable)] - #[inline] - pub(crate) fn poll_stop_callback(&self, id: super::task::Id) { - if let Some(poll_stop) = &self.after_poll_callback { - (poll_stop)(&TaskMeta { - id, - _phantom: std::marker::PhantomData, - }) - } - } -} - -#[derive(Clone)] -pub(crate) struct TaskHooks { - pub(crate) task_spawn_callback: Option, - pub(crate) task_terminate_callback: Option, - #[cfg(tokio_unstable)] - pub(crate) before_poll_callback: Option, - #[cfg(tokio_unstable)] - pub(crate) after_poll_callback: Option, -} - -/// Task metadata supplied to user-provided hooks for task events. -/// -/// **Note**: This is an [unstable API][unstable]. The public API of this type -/// may break in 1.x releases. See [the documentation on unstable -/// features][unstable] for details. -/// -/// [unstable]: crate#unstable-features -#[allow(missing_debug_implementations)] -#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] -pub struct TaskMeta<'a> { - /// The opaque ID of the task. - pub(crate) id: super::task::Id, - pub(crate) _phantom: PhantomData<&'a ()>, -} - -impl<'a> TaskMeta<'a> { - /// Return the opaque ID of the task. 
- #[cfg_attr(not(tokio_unstable), allow(unreachable_pub, dead_code))] - pub fn id(&self) -> super::task::Id { - self.id - } -} - -/// Runs on specific task-related events -pub(crate) type TaskCallback = std::sync::Arc) + Send + Sync>; diff --git a/tokio/src/runtime/task_hooks/mod.rs b/tokio/src/runtime/task_hooks/mod.rs new file mode 100644 index 00000000000..c54c2787991 --- /dev/null +++ b/tokio/src/runtime/task_hooks/mod.rs @@ -0,0 +1,99 @@ +use super::task; +use crate::loom::cell::UnsafeCell; +use std::marker::PhantomData; +use std::ptr::NonNull; +use std::sync::Arc; + +/// A factory which produces new [`TaskHookHarness`] objects for tasks which either have been +/// spawned in "detached mode" via the builder, or which were spawned from outside the runtime or +/// from another context where no [`TaskHookHarness`] was present. +pub trait TaskHookHarnessFactory { + /// Create a new [`TaskHookHarness`] object which the runtime will attach to a given task. + fn on_top_level_spawn( + &self, + ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option>; +} + +/// Trait for user-provided "harness" objects which are attached to tasks and provide hook +/// implementations. +#[allow(unused_variables)] +pub trait TaskHookHarness { + /// Pre-poll task hook which runs arbitrary user logic. + fn before_poll(&mut self, ctx: &mut BeforeTaskPollContext<'_>) {} + + /// Post-poll task hook which runs arbitrary user logic. + fn after_poll(&mut self, ctx: &mut AfterTaskPollContext<'_>) {} + + /// Task hook which runs when this task spawns a child, unless that child is explicitly spawned + /// detached from the parent. + /// + /// This hook creates a harness for the child, or detaches the child from any instrumentation. + fn on_child_spawn( + &mut self, + ctx: &mut OnChildTaskSpawnContext<'_>, + ) -> Option> { + None + } + + /// Task hook which runs on task termination. + fn on_task_terminate(&mut self, ctx: &mut OnTaskTerminateContext<'_>) {} +} + +pub(crate) type OptionalTaskHooksFactory = + Option>; +pub(crate) type OptionalTaskHooks = Option>; + +pub(crate) type OptionalTaskHooksWeak = + UnsafeCell>>; + +pub(crate) type OptionalTaskHooksMut<'a> = + Option<&'a mut (dyn TaskHookHarness + Send + Sync + 'static)>; +pub(crate) type OptionalTaskHooksFactoryRef<'a> = + Option<&'a (dyn TaskHookHarnessFactory + Send + Sync + 'static)>; + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnTopLevelTaskSpawnContext<'a> { + pub(crate) id: task::Id, + pub(crate) _phantom: PhantomData<&'a ()>, +} + +impl<'a> OnTopLevelTaskSpawnContext<'a> { + /// Returns the ID of the task. + pub fn id(&self) -> task::Id { + self.id + } +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnChildTaskSpawnContext<'a> { + pub(crate) id: task::Id, + pub(crate) _phantom: PhantomData<&'a ()>, +} + +impl<'a> OnChildTaskSpawnContext<'a> { + /// Returns the ID of the task. 
+ pub fn id(&self) -> task::Id { + self.id + } +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct OnTaskTerminateContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct BeforeTaskPollContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} + +#[allow(missing_debug_implementations, missing_docs)] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct AfterTaskPollContext<'a> { + pub(crate) _phantom: PhantomData<&'a ()>, +} diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index 11901ebc9e5..042895c6185 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -6,7 +6,9 @@ use self::noop_scheduler::NoopSchedule; use self::unowned_wrapper::unowned; mod noop_scheduler { - use crate::runtime::task::{self, Task, TaskHarnessScheduleHooks}; + use crate::runtime::task::{self, Task}; + #[cfg(tokio_unstable)] + use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; /// `task::Schedule` implementation that does nothing, for testing. pub(crate) struct NoopSchedule; @@ -20,10 +22,14 @@ mod noop_scheduler { unreachable!(); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } } } @@ -41,6 +47,9 @@ mod unowned_wrapper { use tracing::Instrument; let span = tracing::trace_span!("test_span"); let task = task.instrument(span); + #[cfg(tokio_unstable)] + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next(), None); + #[cfg(not(tokio_unstable))] let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } @@ -51,6 +60,9 @@ mod unowned_wrapper { T: std::future::Future + Send + 'static, T::Output: Send + 'static, { + #[cfg(tokio_unstable)] + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next(), None); + #[cfg(not(tokio_unstable))] let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } diff --git a/tokio/src/runtime/tests/queue.rs b/tokio/src/runtime/tests/queue.rs index 8a57ae428e8..320ae680fc3 100644 --- a/tokio/src/runtime/tests/queue.rs +++ b/tokio/src/runtime/tests/queue.rs @@ -1,6 +1,7 @@ use crate::runtime::scheduler::multi_thread::{queue, Stats}; -use crate::runtime::task::{self, Schedule, Task, TaskHarnessScheduleHooks}; - +use crate::runtime::task::{self, Schedule, Task}; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use std::cell::RefCell; use std::thread; use std::time::Duration; @@ -285,9 +286,13 @@ impl Schedule for Runtime { unreachable!(); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } } diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index ea48b8e5199..4cf0de69cf0 100644 --- a/tokio/src/runtime/tests/task.rs +++ 
b/tokio/src/runtime/tests/task.rs @@ -1,8 +1,7 @@ -use crate::runtime::task::{ - self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks, -}; +use crate::runtime::task::{self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task}; use crate::runtime::tests::NoopSchedule; - +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef}; use std::collections::VecDeque; use std::future::Future; use std::sync::atomic::{AtomicBool, Ordering}; @@ -447,9 +446,13 @@ impl Schedule for Runtime { self.0.core.try_lock().unwrap().queue.push_back(task); } - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } } diff --git a/tokio/src/task/builder.rs b/tokio/src/task/builder.rs index 6053352a01c..c34a2ca462e 100644 --- a/tokio/src/task/builder.rs +++ b/tokio/src/task/builder.rs @@ -44,8 +44,12 @@ use std::{future::Future, io, mem}; /// loop { /// let (socket, _) = listener.accept().await?; /// -/// tokio::task::Builder::new() -/// .name("tcp connection handler") +/// let mut builder = tokio::task::Builder::new(); +/// +/// builder +/// .name("tcp connection handler"); +/// +/// builder /// .spawn(async move { /// // Process each socket concurrently. /// process(socket).await @@ -71,8 +75,9 @@ impl<'a> Builder<'a> { } /// Assigns a name to the task which will be spawned. - pub fn name(&self, name: &'a str) -> Self { - Self { name: Some(name) } + pub fn name(&mut self, name: &'a str) -> &mut Self { + self.name = Some(name); + self } /// Spawns a task with this builder's settings on the current runtime. 
@@ -91,9 +96,9 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - super::spawn::spawn_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size), None) } else { - super::spawn::spawn_inner(future, SpawnMeta::new(self.name, fut_size)) + super::spawn::spawn_inner(future, SpawnMeta::new(self.name, fut_size), None) }) } @@ -112,9 +117,9 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - handle.spawn_named(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + handle.spawn_named(Box::pin(future), SpawnMeta::new(self.name, fut_size), None) } else { - handle.spawn_named(future, SpawnMeta::new(self.name, fut_size)) + handle.spawn_named(future, SpawnMeta::new(self.name, fut_size), None) }) } @@ -140,9 +145,13 @@ impl<'a> Builder<'a> { { let fut_size = mem::size_of::(); Ok(if fut_size > BOX_FUTURE_THRESHOLD { - super::local::spawn_local_inner(Box::pin(future), SpawnMeta::new(self.name, fut_size)) + super::local::spawn_local_inner( + Box::pin(future), + SpawnMeta::new(self.name, fut_size), + None, + ) } else { - super::local::spawn_local_inner(future, SpawnMeta::new(self.name, fut_size)) + super::local::spawn_local_inner(future, SpawnMeta::new(self.name, fut_size), None) }) } diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index a156719a067..2c43b9c989d 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -641,9 +641,13 @@ where #[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "tracing"))))] impl<'a, T: 'static> Builder<'a, T> { /// Assigns a name to the task which will be spawned. - pub fn name(self, name: &'a str) -> Self { - let builder = self.builder.name(name); - Self { builder, ..self } + pub fn name(mut self, name: &'a str) -> Self { + self.builder.name(name); + + Self { + builder: self.builder, + ..self + } } /// Spawn the provided task with this builder's settings and store it in the diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 95bd6404bec..f100e6dc2ca 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -1,9 +1,11 @@ //! Runs `!Send` futures on the current thread. use crate::loom::cell::UnsafeCell; use crate::loom::sync::{Arc, Mutex}; +use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task}; #[cfg(tokio_unstable)] -use crate::runtime; -use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task, TaskHarnessScheduleHooks}; +use crate::runtime::{ + self, OptionalTaskHooks, OptionalTaskHooksFactory, OptionalTaskHooksFactoryRef, +}; use crate::runtime::{context, ThreadId, BOX_FUTURE_THRESHOLD}; use crate::sync::AtomicWaker; use crate::util::trace::SpawnMeta; @@ -371,6 +373,13 @@ cfg_rt! { F::Output: 'static, { let fut_size = std::mem::size_of::(); + #[cfg(tokio_unstable)] + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_local_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + spawn_local_inner(future, SpawnMeta::new_unnamed(fut_size), None) + } + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { spawn_local_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -380,7 +389,7 @@ cfg_rt! 
{ #[track_caller] - pub(super) fn spawn_local_inner(future: F, meta: SpawnMeta<'_>) -> JoinHandle + pub(super) fn spawn_local_inner(future: F, meta: SpawnMeta<'_>, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where F: Future + 'static, F::Output: 'static { @@ -412,6 +421,9 @@ cfg_rt! { let task = crate::util::trace::task(future, "task", meta, id.as_u64()); // safety: we have verified that this is a `LocalRuntime` owned by the current thread + #[cfg(tokio_unstable)] + unsafe { handle.spawn_local(task, id, hooks_override) } + #[cfg(not(tokio_unstable))] unsafe { handle.spawn_local(task, id) } } else { match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) { @@ -1004,6 +1016,15 @@ impl Context { let future = crate::util::trace::task(future, "local", meta, id.as_u64()); // Safety: called from the thread that owns the `LocalSet` + #[cfg(tokio_unstable)] + let (handle, notified) = { + self.shared.local_state.assert_called_from_owner_thread(); + self.shared + .local_state + .owned + .bind(future, self.shared.clone(), id, None) + }; + #[cfg(not(tokio_unstable))] let (handle, notified) = { self.shared.local_state.assert_called_from_owner_thread(); self.shared @@ -1117,11 +1138,15 @@ impl task::Schedule for Arc { Shared::schedule(self, task); } - // localset does not currently support task hooks - fn hooks(&self) -> TaskHarnessScheduleHooks { - TaskHarnessScheduleHooks { - task_terminate_callback: None, - } + #[cfg(tokio_unstable)] + fn hooks_factory(&self) -> OptionalTaskHooksFactory { + None + } + + // localset does not support task hooks + #[cfg(tokio_unstable)] + fn hooks_factory_ref(&self) -> OptionalTaskHooksFactoryRef<'_> { + None } cfg_unstable! { diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index f0c6f71c15a..601793f2f99 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -311,6 +311,10 @@ cfg_rt! { pub use crate::runtime::task::{Id, id, try_id}; + cfg_unstable! { + pub use spawn::spawn_with_hooks; + } + cfg_trace! { mod builder; pub use builder::Builder; diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index 7c748226121..ad1fc6eb19c 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -1,4 +1,6 @@ use crate::runtime::BOX_FUTURE_THRESHOLD; +#[cfg(tokio_unstable)] +use crate::runtime::{OptionalTaskHooks, TaskHookHarness}; use crate::task::JoinHandle; use crate::util::trace::SpawnMeta; @@ -169,6 +171,13 @@ cfg_rt! { F::Output: Send + 'static, { let fut_size = std::mem::size_of::(); + #[cfg(tokio_unstable)] + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), None) + } else { + spawn_inner(future, SpawnMeta::new_unnamed(fut_size), None) + } + #[cfg(not(tokio_unstable))] if fut_size > BOX_FUTURE_THRESHOLD { spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size)) } else { @@ -176,8 +185,26 @@ cfg_rt! 
{ } } + /// Spawn a future with a custom set of task hooks + #[track_caller] + #[cfg(tokio_unstable)] + pub fn spawn_with_hooks(future: F, hooks: T) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + T: TaskHookHarness + Send + Sync + 'static, + { + let fut_size = std::mem::size_of::(); + + if fut_size > BOX_FUTURE_THRESHOLD { + spawn_inner(Box::pin(future), SpawnMeta::new_unnamed(fut_size), Some(Box::new(hooks))) + } else { + spawn_inner(future, SpawnMeta::new_unnamed(fut_size), Some(Box::new(hooks))) + } + } + #[track_caller] - pub(super) fn spawn_inner(future: T, meta: SpawnMeta<'_>) -> JoinHandle + pub(super) fn spawn_inner(future: T, meta: SpawnMeta<'_>, #[cfg(tokio_unstable)] hooks_override: OptionalTaskHooks) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, @@ -199,6 +226,13 @@ cfg_rt! { let id = task::Id::next(); let task = crate::util::trace::task(future, "task", meta, id.as_u64()); + #[cfg(tokio_unstable)] + return match context::with_current(|handle| handle.spawn(task, id, hooks_override)) { + Ok(join_handle) => join_handle, + Err(e) => panic!("{}", e), + }; + + #[cfg(not(tokio_unstable))] match context::with_current(|handle| handle.spawn(task, id)) { Ok(join_handle) => join_handle, Err(e) => panic!("{}", e), diff --git a/tokio/tests/rt_poll_callbacks.rs b/tokio/tests/rt_poll_callbacks.rs deleted file mode 100644 index 8ccff385772..00000000000 --- a/tokio/tests/rt_poll_callbacks.rs +++ /dev/null @@ -1,128 +0,0 @@ -#![allow(unknown_lints, unexpected_cfgs)] -#![cfg(tokio_unstable)] - -use std::sync::{atomic::AtomicUsize, Arc, Mutex}; - -use tokio::task::yield_now; - -#[cfg(not(target_os = "wasi"))] -#[test] -fn callbacks_fire_multi_thread() { - let poll_start_counter = Arc::new(AtomicUsize::new(0)); - let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - let poll_start = poll_start_counter.clone(); - let poll_stop = poll_stop_counter.clone(); - - let before_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - let after_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - - let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); - let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .on_before_task_poll(move |task_meta| { - before_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .on_after_task_poll(move |task_meta| { - after_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .build() - .unwrap(); - let task = rt.spawn(async { - yield_now().await; - yield_now().await; - yield_now().await; - }); - - let spawned_task_id = task.id(); - - rt.block_on(task).expect("task should succeed"); - // We need to drop the runtime to guarantee the workers have exited (and thus called the callback) - drop(rt); - - assert_eq!( - before_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!( - after_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - let actual_count = 4; - assert_eq!( - poll_start.load(std::sync::atomic::Ordering::Relaxed), - actual_count, - "unexpected number of poll starts" - ); - assert_eq!( - poll_stop.load(std::sync::atomic::Ordering::Relaxed), - actual_count, - 
"unexpected number of poll stops" - ); -} - -#[test] -fn callbacks_fire_current_thread() { - let poll_start_counter = Arc::new(AtomicUsize::new(0)); - let poll_stop_counter = Arc::new(AtomicUsize::new(0)); - let poll_start = poll_start_counter.clone(); - let poll_stop = poll_stop_counter.clone(); - - let before_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - let after_task_poll_callback_task_id: Arc>> = - Arc::new(Mutex::new(None)); - - let before_task_poll_callback_task_id_ref = Arc::clone(&before_task_poll_callback_task_id); - let after_task_poll_callback_task_id_ref = Arc::clone(&after_task_poll_callback_task_id); - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .on_before_task_poll(move |task_meta| { - before_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_start_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .on_after_task_poll(move |task_meta| { - after_task_poll_callback_task_id_ref - .lock() - .unwrap() - .replace(task_meta.id()); - poll_stop_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - }) - .build() - .unwrap(); - - let task = rt.spawn(async { - yield_now().await; - yield_now().await; - yield_now().await; - }); - - let spawned_task_id = task.id(); - - let _ = rt.block_on(task); - drop(rt); - - assert_eq!( - before_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!( - after_task_poll_callback_task_id.lock().unwrap().unwrap(), - spawned_task_id - ); - assert_eq!(poll_start.load(std::sync::atomic::Ordering::Relaxed), 4); - assert_eq!(poll_stop.load(std::sync::atomic::Ordering::Relaxed), 4); -} diff --git a/tokio/tests/task_builder.rs b/tokio/tests/task_builder.rs index c700f229f9f..63cd9d925f1 100644 --- a/tokio/tests/task_builder.rs +++ b/tokio/tests/task_builder.rs @@ -8,22 +8,22 @@ use tokio::{ #[test] async fn spawn_with_name() { - let result = Builder::new() - .name("name") - .spawn(async { "task executed" }) - .unwrap() - .await; + let mut b = Builder::new(); + + b.name("name"); + + let result = b.spawn(async { "task executed" }).unwrap().await; assert_eq!(result.unwrap(), "task executed"); } #[test] async fn spawn_blocking_with_name() { - let result = Builder::new() - .name("name") - .spawn_blocking(|| "task executed") - .unwrap() - .await; + let mut b = Builder::new(); + + b.name("name"); + + let result = b.spawn_blocking(|| "task executed").unwrap().await; assert_eq!(result.unwrap(), "task executed"); } @@ -33,11 +33,11 @@ async fn spawn_local_with_name() { let unsend_data = Rc::new("task executed"); let result = LocalSet::new() .run_until(async move { - Builder::new() - .name("name") - .spawn_local(async move { unsend_data }) - .unwrap() - .await + let mut b = Builder::new(); + + b.name("name"); + + b.spawn_local(async move { unsend_data }).unwrap().await }) .await; diff --git a/tokio/tests/task_hooks.rs b/tokio/tests/task_hooks.rs index 185b9126cca..372af778cfd 100644 --- a/tokio/tests/task_hooks.rs +++ b/tokio/tests/task_hooks.rs @@ -1,75 +1,445 @@ -#![warn(rust_2018_idioms)] -#![cfg(all(feature = "full", tokio_unstable, target_has_atomic = "64"))] +#![cfg(all( + feature = "full", + tokio_unstable, + target_has_atomic = "64", + not(target_arch = "wasm32") +))] -use std::collections::HashSet; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; +use tokio::runtime; +use tokio::runtime::{ + AfterTaskPollContext, BeforeTaskPollContext, OnChildTaskSpawnContext, 
OnTaskTerminateContext, + OnTopLevelTaskSpawnContext, TaskHookHarness, TaskHookHarnessFactory, +}; -use tokio::runtime::Builder; +#[test] +fn runtime_default_factory() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + let mta = runtime::Builder::new_multi_thread_alt(); + + run_runtime_default_factory(ct); + run_runtime_default_factory(mt); + run_runtime_default_factory(mta); +} + +#[test] +fn parent_child_chaining() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + let mta = runtime::Builder::new_multi_thread_alt(); + + run_parent_child_chaining(ct); + run_parent_child_chaining(mt); + run_parent_child_chaining(mta); +} -const TASKS: usize = 8; -const ITERATIONS: usize = 64; -/// Assert that the spawn task hook always fires when set. #[test] -fn spawn_task_hook_fires() { - let count = Arc::new(AtomicUsize::new(0)); - let count2 = Arc::clone(&count); +fn before_poll() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + let mta = runtime::Builder::new_multi_thread_alt(); - let ids = Arc::new(Mutex::new(HashSet::new())); - let ids2 = Arc::clone(&ids); + run_before_poll(ct); + run_before_poll(mt); + run_before_poll(mta); +} + +#[test] +fn after_poll() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + let mta = runtime::Builder::new_multi_thread_alt(); - let runtime = Builder::new_current_thread() - .on_task_spawn(move |data| { - ids2.lock().unwrap().insert(data.id()); + run_after_poll(ct); + run_after_poll(mt); + run_after_poll(mta); +} + +#[test] +fn terminate() { + let ct = runtime::Builder::new_current_thread(); + + run_terminate(ct); +} + +#[test] +fn hook_switching() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + let mta = runtime::Builder::new_multi_thread_alt(); + + run_hook_switching(ct); + run_hook_switching(mt); + run_hook_switching(mta); +} - count2.fetch_add(1, Ordering::SeqCst); +#[test] +fn override_hooks() { + let ct = runtime::Builder::new_current_thread(); + let mt = runtime::Builder::new_multi_thread(); + let mta = runtime::Builder::new_multi_thread_alt(); + + run_override(ct); + run_override(mt); + run_override(mta); +} + +fn run_runtime_default_factory(mut builder: runtime::Builder) { + struct TestFactory { + counter: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + self.counter.fetch_add(1, Ordering::SeqCst); + None + } + } + + let counter = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + counter: counter.clone(), }) .build() .unwrap(); - for _ in 0..TASKS { - runtime.spawn(std::future::pending::<()>()); + rt.spawn(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 1); + + let handle = rt.handle(); + + handle.spawn(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 2); + + rt.block_on(async {}); + + assert_eq!(counter.load(Ordering::SeqCst), 2); + + rt.block_on(async { tokio::spawn(async {}) }); + + assert_eq!(counter.load(Ordering::SeqCst), 3); + + // block on a future which spawns a future and waits for it, which in turn spawns another future + // + // this checks that stuff works from on-worker within a multithreaded runtime + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}) }).await }); + + 
assert_eq!(counter.load(Ordering::SeqCst), 5); +} + +fn run_parent_child_chaining(mut builder: runtime::Builder) { + struct TestFactory { + parent_spawns: Arc, + child_spawns: Arc, + } + + struct TestHooks { + spawns: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + self.parent_spawns.fetch_add(1, Ordering::SeqCst); + + Some(Box::new(TestHooks { + spawns: self.child_spawns.clone(), + })) + } } - let count_realized = count.load(Ordering::SeqCst); - assert_eq!( - TASKS, count_realized, - "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}", - count_realized - ); + impl TaskHookHarness for TestHooks { + fn on_child_spawn( + &mut self, + _ctx: &mut OnChildTaskSpawnContext<'_>, + ) -> Option> { + self.spawns.fetch_add(1, Ordering::SeqCst); + + Some(Box::new(Self { + spawns: self.spawns.clone(), + })) + } + } + + let parent_spawns = Arc::new(AtomicUsize::new(0)); + let child_spawns = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + parent_spawns: parent_spawns.clone(), + child_spawns: child_spawns.clone(), + }) + .build() + .unwrap(); + + rt.spawn(async {}); - let count_ids_realized = ids.lock().unwrap().len(); + assert_eq!(parent_spawns.load(Ordering::SeqCst), 1); + assert_eq!(child_spawns.load(Ordering::SeqCst), 0); - assert_eq!( - TASKS, count_ids_realized, - "Total number of spawned task hook invocations was incorrect, expected {TASKS}, got {}", - count_realized - ); + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}) }).await }); + + assert_eq!(parent_spawns.load(Ordering::SeqCst), 2); + assert_eq!(child_spawns.load(Ordering::SeqCst), 1); } -/// Assert that the terminate task hook always fires when set. 
-#[test] -fn terminate_task_hook_fires() { - let count = Arc::new(AtomicUsize::new(0)); - let count2 = Arc::clone(&count); +fn run_before_poll(mut builder: runtime::Builder) { + struct TestFactory { + polls: Arc, + } + + struct TestHooks { + polls: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + polls: self.polls.clone(), + })) + } + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + self.polls.fetch_add(1, Ordering::SeqCst); + } + } + + let polls = Arc::new(AtomicUsize::new(0)); - let runtime = Builder::new_current_thread() - .on_task_terminate(move |_data| { - count2.fetch_add(1, Ordering::SeqCst); + let rt = builder + .hook_harness_factory(TestFactory { + polls: polls.clone(), }) .build() .unwrap(); - for _ in 0..TASKS { - runtime.spawn(std::future::ready(())); + rt.block_on(async {}); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 4); +} + +fn run_after_poll(mut builder: runtime::Builder) { + struct TestFactory { + polls: Arc, + } + + struct TestHooks { + polls: Arc, } - runtime.block_on(async { - // tick the runtime a bunch to close out tasks - for _ in 0..ITERATIONS { - tokio::task::yield_now().await; + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + polls: self.polls.clone(), + })) } + } + + impl TaskHookHarness for TestHooks { + fn after_poll(&mut self, _ctx: &mut AfterTaskPollContext<'_>) { + self.polls.fetch_add(1, Ordering::SeqCst); + } + } + + let polls = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + polls: polls.clone(), + }) + .build() + .unwrap(); + + rt.block_on(async {}); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 4); +} + +fn run_terminate(mut builder: runtime::Builder) { + struct TestFactory { + terminations: Arc, + } + + struct TestHooks { + terminations: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + terminations: self.terminations.clone(), + })) + } + } + + impl TaskHookHarness for TestHooks { + fn on_task_terminate(&mut self, _ctx: &mut OnTaskTerminateContext<'_>) { + self.terminations.fetch_add(1, Ordering::SeqCst); + } + } + + let terminations = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + terminations: terminations.clone(), + }) + .build() + .unwrap(); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + + assert_eq!(terminations.load(Ordering::SeqCst), 2); +} + +fn run_hook_switching(mut builder: runtime::Builder) { + struct TestFactory { + next_id: Arc, + flag: Arc, + } + + struct TestHooks { + id: usize, + flag: Arc, + } + + impl TaskHookHarnessFactory for TestFactory { + fn 
on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(TestHooks { + id: self.next_id.fetch_add(1, Ordering::SeqCst), + flag: self.flag.clone(), + })) + } + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + self.flag.store(self.id, Ordering::SeqCst); + } + } + + let polls = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + next_id: Arc::new(Default::default()), + flag: polls.clone(), + }) + .build() + .unwrap(); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 0); + + let _ = rt.block_on(async { tokio::spawn(async { tokio::spawn(async {}).await }).await }); + assert_eq!(polls.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { tokio::spawn(async {}).await }); + assert_eq!(polls.load(Ordering::SeqCst), 3); +} + +fn run_override(mut builder: runtime::Builder) { + struct TestFactory { + counter: Arc, + } + + struct TestHooks { + counter: Arc, + } + + impl TaskHookHarness for TestHooks { + fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) { + self.counter.fetch_add(1, Ordering::SeqCst); + } + + fn on_child_spawn( + &mut self, + _ctx: &mut OnChildTaskSpawnContext<'_>, + ) -> Option> { + Some(Box::new(Self { + counter: self.counter.clone(), + })) + } + } + + impl TaskHookHarnessFactory for TestFactory { + fn on_top_level_spawn( + &self, + _ctx: &mut OnTopLevelTaskSpawnContext<'_>, + ) -> Option> { + self.counter.fetch_add(1, Ordering::SeqCst); + None + } + } + + let factory_counter = Arc::new(AtomicUsize::new(0)); + let builder_counter = Arc::new(AtomicUsize::new(0)); + + let rt = builder + .hook_harness_factory(TestFactory { + counter: factory_counter.clone(), + }) + .build() + .unwrap(); + + rt.spawn(async {}); + + assert_eq!(factory_counter.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { + tokio::task::spawn_with_hooks( + async {}, + TestHooks { + counter: builder_counter.clone(), + }, + ) + .await + }); + + assert_eq!(factory_counter.load(Ordering::SeqCst), 1); + assert_eq!(builder_counter.load(Ordering::SeqCst), 1); + + let _ = rt.block_on(async { + let counter = builder_counter.clone(); + tokio::spawn(async { tokio::task::spawn_with_hooks(async {}, TestHooks { counter }).await }) + .await }); - assert_eq!(TASKS, count.load(Ordering::SeqCst)); + assert_eq!(factory_counter.load(Ordering::SeqCst), 2); + assert_eq!(builder_counter.load(Ordering::SeqCst), 2); } diff --git a/tokio/tests/tracing_task.rs b/tokio/tests/tracing_task.rs index a9317bf5b12..f2adf573a9d 100644 --- a/tokio/tests/tracing_task.rs +++ b/tokio/tests/tracing_task.rs @@ -64,9 +64,11 @@ async fn task_builder_name_recorded() { { let _guard = tracing::subscriber::set_default(subscriber); - task::Builder::new() - .name("test-task") - .spawn(futures::future::ready(())) + let mut b = task::Builder::new(); + + b.name("test-task"); + + b.spawn(futures::future::ready(())) .unwrap() .await .expect("failed to await join handle");
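
Illustrative usage sketch (not part of the diff above): the snippet below shows how the pieces added by this patch are intended to fit together, modeled on the API exercised in tokio/tests/task_hooks.rs. It assumes a build with `--cfg tokio_unstable`; the `CountingFactory` and `CountingHarness` names and the poll-counting behavior are invented for illustration, while the trait, context, builder, and spawn APIs are the ones introduced in the patch.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::runtime::{
    BeforeTaskPollContext, OnChildTaskSpawnContext, OnTopLevelTaskSpawnContext, TaskHookHarness,
    TaskHookHarnessFactory,
};

// Per-task harness: counts polls for a task and every descendant it spawns.
struct CountingHarness {
    polls: Arc<AtomicUsize>,
}

impl TaskHookHarness for CountingHarness {
    fn before_poll(&mut self, _ctx: &mut BeforeTaskPollContext<'_>) {
        self.polls.fetch_add(1, Ordering::Relaxed);
    }

    fn on_child_spawn(
        &mut self,
        _ctx: &mut OnChildTaskSpawnContext<'_>,
    ) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
        // Children share the parent's counter.
        Some(Box::new(CountingHarness {
            polls: self.polls.clone(),
        }))
    }
}

// Runtime-wide factory: hands a harness to every top-level task.
struct CountingFactory {
    polls: Arc<AtomicUsize>,
}

impl TaskHookHarnessFactory for CountingFactory {
    fn on_top_level_spawn(
        &self,
        _ctx: &mut OnTopLevelTaskSpawnContext<'_>,
    ) -> Option<Box<dyn TaskHookHarness + Send + Sync + 'static>> {
        Some(Box::new(CountingHarness {
            polls: self.polls.clone(),
        }))
    }
}

fn main() {
    let polls = Arc::new(AtomicUsize::new(0));

    let rt = tokio::runtime::Builder::new_multi_thread()
        .hook_harness_factory(CountingFactory {
            polls: polls.clone(),
        })
        .build()
        .unwrap();

    // Both the top-level task and its child are instrumented via the factory.
    rt.block_on(async {
        tokio::spawn(async {
            tokio::spawn(async {}).await.unwrap();
        })
        .await
        .unwrap();
    });

    // An explicit harness can be attached per task, bypassing the factory.
    let explicit = Arc::new(AtomicUsize::new(0));
    rt.block_on(async {
        tokio::task::spawn_with_hooks(
            async {},
            CountingHarness {
                polls: explicit.clone(),
            },
        )
        .await
        .unwrap();
    });

    println!("factory-instrumented polls: {}", polls.load(Ordering::Relaxed));
    println!("explicitly instrumented polls: {}", explicit.load(Ordering::Relaxed));
}

Each top-level task receives a harness from the runtime's factory, children inherit one through `on_child_spawn`, and `spawn_with_hooks` attaches an explicit harness to a single task without consulting the factory.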