Add parallel-task-set crate, test it, use it #8174

Open · wants to merge 2 commits into base: main
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions Cargo.toml
@@ -116,6 +116,7 @@ members = [
"oximeter/timeseries-macro",
"oximeter/types",
"package",
"parallel-task-set",
"passwords",
"range-requests",
"rpaths",
@@ -267,6 +268,7 @@ default-members = [
"oximeter/timeseries-macro",
"oximeter/types",
"package",
"parallel-task-set",
"passwords",
"range-requests",
"rpaths",
@@ -583,6 +585,7 @@ oximeter-timeseries-macro = { path = "oximeter/timeseries-macro" }
oximeter-types = { path = "oximeter/types" }
oxql-types = { path = "oximeter/oxql-types" }
p256 = "0.13"
parallel-task-set = { path = "parallel-task-set" }
parse-display = "0.10.0"
partial-io = { version = "0.5.4", features = ["proptest1", "tokio1"] }
parse-size = "1.1.0"
1 change: 1 addition & 0 deletions nexus/Cargo.toml
@@ -69,6 +69,7 @@ oximeter-client.workspace = true
oximeter-db = { workspace = true, default-features = false, features = [ "oxql" ] }
oxnet.workspace = true
oxql-types.workspace = true
parallel-task-set.workspace = true
parse-display.workspace = true
paste.workspace = true
# See omicron-rpaths for more about the "pq-sys" dependency.
54 changes: 20 additions & 34 deletions nexus/src/app/background/tasks/support_bundle_collector.rs
@@ -34,6 +34,7 @@ use omicron_uuid_kinds::OmicronZoneUuid;
use omicron_uuid_kinds::SledUuid;
use omicron_uuid_kinds::SupportBundleUuid;
use omicron_uuid_kinds::ZpoolUuid;
use parallel_task_set::ParallelTaskSet;
use serde_json::json;
use sha2::{Digest, Sha256};
use std::future::Future;
@@ -42,7 +43,6 @@ use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncSeekExt;
use tokio::io::SeekFrom;
use tokio::task::JoinSet;
use tufaceous_artifact::ArtifactHash;
use zip::ZipArchive;
use zip::ZipWriter;
@@ -566,41 +566,27 @@ impl BundleCollection {
report.listed_in_service_sleds = true;

const MAX_CONCURRENT_SLED_REQUESTS: usize = 16;
let mut sleds_iter = all_sleds.into_iter().peekable();
let mut tasks = JoinSet::new();

// While we have incoming work to send to tasks (sleds_iter)
// or a task operating on that data (tasks)...
while sleds_iter.peek().is_some() || !tasks.is_empty() {
// Spawn tasks up to the concurrency limit
while tasks.len() < MAX_CONCURRENT_SLED_REQUESTS
&& sleds_iter.peek().is_some()
{
if let Some(sled) = sleds_iter.next() {
let collection: Arc<BundleCollection> = self.clone();
let dir = dir.path().to_path_buf();
tasks.spawn({
async move {
collection
.collect_data_from_sled(&sled, &dir)
.await
}
});
}
}
let mut set = ParallelTaskSet::new_with_parallelism(
MAX_CONCURRENT_SLED_REQUESTS,
);

// Await the completion of ongoing tasks.
//
// Keep collecting from other sleds, even if one or more of the
// sled collection tasks fail.
if let Some(result) = tasks.join_next().await {
if let Err(err) = result {
warn!(
&self.log,
"Failed to fully collect support bundle info from sled";
"err" => ?err
);
for sled in all_sleds.into_iter() {
set.spawn({
let collection: Arc<BundleCollection> = self.clone();
let dir = dir.path().to_path_buf();
async move {
collection.collect_data_from_sled(&sled, &dir).await
}
});
}

while let Some(result) = set.join_next().await {
if let Err(err) = result {
warn!(
&self.log,
"Failed to fully collect support bundle info from sled";
"err" => ?err
);
}
}
}
15 changes: 15 additions & 0 deletions parallel-task-set/Cargo.toml
@@ -0,0 +1,15 @@
[package]
name = "parallel-task-set"
version = "0.1.0"
edition = "2021"
license = "MPL-2.0"

[dependencies]
tokio.workspace = true
omicron-workspace-hack.workspace = true

[dev-dependencies]
rand.workspace = true

[lints]
workspace = true
146 changes: 146 additions & 0 deletions parallel-task-set/src/lib.rs
@@ -0,0 +1,146 @@
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

use std::sync::Arc;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;

/// The default maximum number of tasks run in parallel by [ParallelTaskSet].
pub const DEFAULT_MAX_PARALLELISM: usize = 16;

/// A set of futures which execute in parallel on distinct tokio tasks,
/// up to a user-specified maximum amount of parallelism.
///
/// This parallelism is achieved by spawning tasks on a [JoinSet],
/// and may be further limited by the underlying machine's ability
/// to execute many tokio tasks.
///
/// # Why not just use FuturesUnordered?
///
/// FuturesUnordered can execute any number of futures concurrently,
/// but makes no attempt to execute them in parallel (assuming the underlying
/// futures are not themselves spawning additional tasks).
///
/// # Why not just use a JoinSet?
///
/// The tokio [JoinSet] has no limit on the maximum number of tasks.
/// Given a bursty workload, it's possible to spawn an enormous number
/// of tasks, which may not be desirable.
///
/// Although [ParallelTaskSet] uses a [JoinSet] internally, it also
/// respects a parallelism capacity.
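///
/// # Example
///
/// A minimal usage sketch; the workload closure here is illustrative only:
///
/// ```
/// # async fn example() {
/// use parallel_task_set::ParallelTaskSet;
///
/// // Allow at most four tasks to run at the same time.
/// let mut set = ParallelTaskSet::new_with_parallelism(4);
/// for i in 0..32u32 {
///     // Each future is spawned immediately, but waits for a permit
///     // before it starts executing.
///     set.spawn(async move { i * 2 });
/// }
/// // Wait for every task to finish and collect the outputs.
/// let results = set.join_all().await;
/// assert_eq!(results.len(), 32);
/// # }
/// ```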
pub struct ParallelTaskSet<T> {
semaphore: Arc<Semaphore>,
set: JoinSet<T>,
}

impl<T: 'static + Send> Default for ParallelTaskSet<T> {
fn default() -> Self {
ParallelTaskSet::new()
}
}

impl<T: 'static + Send> ParallelTaskSet<T> {
/// Creates a new [ParallelTaskSet], with [DEFAULT_MAX_PARALLELISM] as the
/// maximum number of tasks to run in parallel.
///
/// If a different amount of parallelism is desired, refer to
/// [Self::new_with_parallelism].
pub fn new() -> ParallelTaskSet<T> {
Self::new_with_parallelism(DEFAULT_MAX_PARALLELISM)
}

/// Creates a new [ParallelTaskSet], with `max_parallelism` as the
/// maximum number of tasks to run in parallel.
pub fn new_with_parallelism(max_parallelism: usize) -> ParallelTaskSet<T> {
let semaphore = Arc::new(Semaphore::new(max_parallelism));
let set = JoinSet::new();

Self { semaphore, set }
}

/// Spawns a task immediately, but only allows it to begin executing
/// once the task set is within its maximum parallelism constraint.
pub fn spawn<F>(&mut self, command: F)
where
F: std::future::Future<Output = T> + Send + 'static,
{
let semaphore = Arc::clone(&self.semaphore);
let _abort_handle = self.set.spawn(async move {
// Hold onto the permit until the command finishes executing
let _permit =
semaphore.acquire_owned().await.expect("semaphore acquire");
command.await
});
}

/// Waits for the next task to complete and returns its output.
///
/// # Panics
///
/// This method panics if the task returns a JoinError.
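///
/// # Example
///
/// A sketch of draining results as tasks finish:
///
/// ```
/// # async fn example() {
/// use parallel_task_set::ParallelTaskSet;
///
/// let mut set = ParallelTaskSet::new_with_parallelism(2);
/// for i in 0..8u64 {
///     set.spawn(async move { i });
/// }
/// // Results arrive in completion order, not in spawn order.
/// while let Some(value) = set.join_next().await {
///     println!("task finished with {value}");
/// }
/// # }
/// ```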
pub async fn join_next(&mut self) -> Option<T> {
self.set.join_next().await.map(|r| r.expect("Failed to join task"))
}

/// Waits for all spawned tasks to complete and returns their output.
///
/// # Panics
///
/// This method panics if any of the tasks returns a JoinError.
pub async fn join_all(self) -> Vec<T> {
self.set.join_all().await
}
}

#[cfg(test)]
mod test {
use super::*;
use rand::Rng;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;

#[tokio::test]
async fn test_spawn_many() {
let count = Arc::new(AtomicUsize::new(0));

let task_limit = 16;
let mut set = ParallelTaskSet::new_with_parallelism(task_limit);

for _ in 0..task_limit * 10 {
set.spawn({
let count = count.clone();
async move {
// How many tasks - including our own - are running right
// now?
let watermark = count.fetch_add(1, Ordering::SeqCst) + 1;

// The tasks should all execute for a short but variable
// amount of time.
let duration_ms = rand::thread_rng().gen_range(0..10);
tokio::time::sleep(tokio::time::Duration::from_millis(
duration_ms,
))
.await;

count.fetch_sub(1, Ordering::SeqCst);

watermark
}
});
}

let watermarks = set.join_all().await;

for (i, watermark) in watermarks.into_iter().enumerate() {
println!("task {i} saw {watermark} concurrent tasks");

assert!(
watermark <= task_limit,
"Observed simultaneous task execution of {watermark} tasks on the {i}-th worker"
);
}
}
}