diff --git a/Cargo.lock b/Cargo.lock
index 453647cf5e..03bb44ffbd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7414,6 +7414,7 @@ dependencies = [
  "oximeter-producer",
  "oxnet",
  "oxql-types",
+ "parallel-task-set",
  "parse-display",
  "paste",
  "pem",
@@ -8825,6 +8826,15 @@ dependencies = [
  "unicode-width 0.2.0",
 ]
 
+[[package]]
+name = "parallel-task-set"
+version = "0.1.0"
+dependencies = [
+ "omicron-workspace-hack",
+ "rand 0.8.5",
+ "tokio",
+]
+
 [[package]]
 name = "parking"
 version = "2.2.1"
diff --git a/Cargo.toml b/Cargo.toml
index f8c8df4c1e..6d1909d30d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -116,6 +116,7 @@ members = [
     "oximeter/timeseries-macro",
     "oximeter/types",
     "package",
+    "parallel-task-set",
     "passwords",
     "range-requests",
     "rpaths",
@@ -267,6 +268,7 @@ default-members = [
     "oximeter/timeseries-macro",
     "oximeter/types",
     "package",
+    "parallel-task-set",
     "passwords",
     "range-requests",
     "rpaths",
@@ -583,6 +585,7 @@ oximeter-timeseries-macro = { path = "oximeter/timeseries-macro" }
 oximeter-types = { path = "oximeter/types" }
 oxql-types = { path = "oximeter/oxql-types" }
 p256 = "0.13"
+parallel-task-set = { path = "parallel-task-set" }
 parse-display = "0.10.0"
 partial-io = { version = "0.5.4", features = ["proptest1", "tokio1"] }
 parse-size = "1.1.0"
diff --git a/nexus/Cargo.toml b/nexus/Cargo.toml
index 2f102e8863..14a01e6e54 100644
--- a/nexus/Cargo.toml
+++ b/nexus/Cargo.toml
@@ -69,6 +69,7 @@ oximeter-client.workspace = true
 oximeter-db = { workspace = true, default-features = false, features = [ "oxql" ] }
 oxnet.workspace = true
 oxql-types.workspace = true
+parallel-task-set.workspace = true
 parse-display.workspace = true
 paste.workspace = true
 # See omicron-rpaths for more about the "pq-sys" dependency.
diff --git a/nexus/src/app/background/tasks/support_bundle_collector.rs b/nexus/src/app/background/tasks/support_bundle_collector.rs
index ab61422332..6ef99263fe 100644
--- a/nexus/src/app/background/tasks/support_bundle_collector.rs
+++ b/nexus/src/app/background/tasks/support_bundle_collector.rs
@@ -34,6 +34,7 @@ use omicron_uuid_kinds::OmicronZoneUuid;
 use omicron_uuid_kinds::SledUuid;
 use omicron_uuid_kinds::SupportBundleUuid;
 use omicron_uuid_kinds::ZpoolUuid;
+use parallel_task_set::ParallelTaskSet;
 use serde_json::json;
 use sha2::{Digest, Sha256};
 use std::future::Future;
@@ -42,7 +43,6 @@ use std::sync::Arc;
 use tokio::io::AsyncReadExt;
 use tokio::io::AsyncSeekExt;
 use tokio::io::SeekFrom;
-use tokio::task::JoinSet;
 use tufaceous_artifact::ArtifactHash;
 use zip::ZipArchive;
 use zip::ZipWriter;
@@ -566,41 +566,27 @@ impl BundleCollection {
         report.listed_in_service_sleds = true;
 
         const MAX_CONCURRENT_SLED_REQUESTS: usize = 16;
-        let mut sleds_iter = all_sleds.into_iter().peekable();
-        let mut tasks = JoinSet::new();
-
-        // While we have incoming work to send to tasks (sleds_iter)
-        // or a task operating on that data (tasks)...
-        while sleds_iter.peek().is_some() || !tasks.is_empty() {
-            // Spawn tasks up to the concurrency limit
-            while tasks.len() < MAX_CONCURRENT_SLED_REQUESTS
-                && sleds_iter.peek().is_some()
-            {
-                if let Some(sled) = sleds_iter.next() {
-                    let collection: Arc<Self> = self.clone();
-                    let dir = dir.path().to_path_buf();
-                    tasks.spawn({
-                        async move {
-                            collection
                                .collect_data_from_sled(&sled, &dir)
                                .await
-                        }
-                    });
-                }
-            }
+        let mut set = ParallelTaskSet::new_with_parallelism(
+            MAX_CONCURRENT_SLED_REQUESTS,
+        );
 
-            // Await the completion of ongoing tasks.
-            //
-            // Keep collecting from other sleds, even if one or more of the
-            // sled collection tasks fail.
-            if let Some(result) = tasks.join_next().await {
-                if let Err(err) = result {
-                    warn!(
-                        &self.log,
-                        "Failed to fully collect support bundle info from sled";
-                        "err" => ?err
-                    );
+        for sled in all_sleds.into_iter() {
+            set.spawn({
+                let collection: Arc<Self> = self.clone();
+                let dir = dir.path().to_path_buf();
+                async move {
+                    collection.collect_data_from_sled(&sled, &dir).await
                 }
+            });
+        }
+
+        while let Some(result) = set.join_next().await {
+            if let Err(err) = result {
+                warn!(
+                    &self.log,
+                    "Failed to fully collect support bundle info from sled";
+                    "err" => ?err
+                );
             }
         }
     }
 }
diff --git a/parallel-task-set/Cargo.toml b/parallel-task-set/Cargo.toml
new file mode 100644
index 0000000000..a53f1945fc
--- /dev/null
+++ b/parallel-task-set/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "parallel-task-set"
+version = "0.1.0"
+edition = "2021"
+license = "MPL-2.0"
+
+[dependencies]
+tokio.workspace = true
+omicron-workspace-hack.workspace = true
+
+[dev-dependencies]
+rand.workspace = true
+
+[lints]
+workspace = true
diff --git a/parallel-task-set/src/lib.rs b/parallel-task-set/src/lib.rs
new file mode 100644
index 0000000000..3cbe333297
--- /dev/null
+++ b/parallel-task-set/src/lib.rs
@@ -0,0 +1,146 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+use std::sync::Arc;
+use tokio::sync::Semaphore;
+use tokio::task::JoinSet;
+
+/// The default number of parallel tasks used by [ParallelTaskSet].
+pub const DEFAULT_MAX_PARALLELISM: usize = 16;
+
+/// A collection of tasks which execute in parallel on distinct tokio
+/// tasks, up to a user-specified maximum amount of parallelism.
+///
+/// This parallelism is achieved by spawning tasks on a [JoinSet],
+/// and may be further limited by the underlying machine's ability
+/// to execute many tokio tasks.
+///
+/// # Why not just use FuturesUnordered?
+///
+/// FuturesUnordered can execute any number of futures concurrently,
+/// but makes no attempt to execute them in parallel (assuming the underlying
+/// futures are not themselves spawning additional tasks).
+///
+/// # Why not just use a JoinSet?
+///
+/// The tokio [JoinSet] has no limit on the maximum number of tasks.
+/// Given a bursty workload, it's possible to spawn an enormous number
+/// of tasks, which may not be desirable.
+///
+/// Although [ParallelTaskSet] uses a [JoinSet] internally, it also
+/// respects a parallelism capacity.
+pub struct ParallelTaskSet<T> {
+    semaphore: Arc<Semaphore>,
+    set: JoinSet<T>,
+}
+
+impl<T: Send + 'static> Default for ParallelTaskSet<T> {
+    fn default() -> Self {
+        ParallelTaskSet::new()
+    }
+}
+
+impl<T: Send + 'static> ParallelTaskSet<T> {
+    /// Creates a new [ParallelTaskSet], with [DEFAULT_MAX_PARALLELISM] as the
+    /// maximum number of tasks to run in parallel.
+    ///
+    /// If a different amount of parallelism is desired, refer to
+    /// [Self::new_with_parallelism].
+    pub fn new() -> ParallelTaskSet<T> {
+        Self::new_with_parallelism(DEFAULT_MAX_PARALLELISM)
+    }
+
+    /// Creates a new [ParallelTaskSet], with `max_parallelism` as the
+    /// maximum number of tasks to run in parallel.
+    pub fn new_with_parallelism(max_parallelism: usize) -> ParallelTaskSet<T> {
+        let semaphore = Arc::new(Semaphore::new(max_parallelism));
+        let set = JoinSet::new();
+
+        Self { semaphore, set }
+    }
+
+    /// Spawns the task immediately, but only allows it to begin executing
+    /// once the task set is within its maximum parallelism constraint.
+    pub fn spawn<F>(&mut self, command: F)
+    where
+        F: std::future::Future<Output = T> + Send + 'static,
+    {
+        let semaphore = Arc::clone(&self.semaphore);
+        let _abort_handle = self.set.spawn(async move {
+            // Hold onto the permit until the command finishes executing
+            let _permit =
+                semaphore.acquire_owned().await.expect("semaphore acquire");
+            command.await
+        });
+    }
+
+    /// Waits for the next task to complete and returns its output.
+    ///
+    /// # Panics
+    ///
+    /// This method panics if the task returns a JoinError.
+    pub async fn join_next(&mut self) -> Option<T> {
+        self.set.join_next().await.map(|r| r.expect("Failed to join task"))
+    }
+
+    /// Waits for all commands to finish executing and returns their output.
+    ///
+    /// # Panics
+    ///
+    /// This method panics if any of the tasks return a JoinError.
+    pub async fn join_all(self) -> Vec<T> {
+        self.set.join_all().await
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use rand::Rng;
+    use std::sync::Arc;
+    use std::sync::atomic::AtomicUsize;
+    use std::sync::atomic::Ordering;
+
+    #[tokio::test]
+    async fn test_spawn_many() {
+        let count = Arc::new(AtomicUsize::new(0));
+
+        let task_limit = 16;
+        let mut set = ParallelTaskSet::new_with_parallelism(task_limit);
+
+        for _ in 0..task_limit * 10 {
+            set.spawn({
+                let count = count.clone();
+                async move {
+                    // How many tasks - including our own - are running right
+                    // now?
+                    let watermark = count.fetch_add(1, Ordering::SeqCst) + 1;
+
+                    // The tasks should all execute for a short but variable
+                    // amount of time.
+                    let duration_ms = rand::thread_rng().gen_range(0..10);
+                    tokio::time::sleep(tokio::time::Duration::from_millis(
+                        duration_ms,
+                    ))
+                    .await;
+
+                    count.fetch_sub(1, Ordering::SeqCst);
+
+                    watermark
+                }
+            });
+        }
+
+        let watermarks = set.join_all().await;
+
+        for (i, watermark) in watermarks.into_iter().enumerate() {
+            println!("task {i} saw {watermark} concurrent tasks");
+
+            assert!(
+                watermark <= task_limit,
+                "Observed simultaneous task execution of {watermark} tasks on the {i}-th worker"
+            );
+        }
+    }
+}
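
For context on how the new crate is meant to be consumed, here is a minimal, self-contained usage sketch. It is illustrative only and not part of this diff; it assumes tokio with the multi-threaded runtime and macros enabled, and the per-item work (a short sleep returning a number) is a stand-in for real work such as a request to one sled.

    use parallel_task_set::ParallelTaskSet;
    use std::time::Duration;

    #[tokio::main]
    async fn main() {
        // Cap execution at 8 parallel tasks, no matter how many are spawned.
        let mut set = ParallelTaskSet::new_with_parallelism(8);

        for i in 0u64..100 {
            set.spawn(async move {
                // Placeholder per-item work; each future runs on its own
                // tokio task once a semaphore permit is available.
                tokio::time::sleep(Duration::from_millis(5)).await;
                i * i
            });
        }

        // join_all waits for every spawned task and returns their outputs.
        let results = set.join_all().await;
        assert_eq!(results.len(), 100);
    }

The support bundle collector above follows the same pattern, except that it drains results with `join_next` so it can log per-sled failures as they complete.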