diff --git a/crates/bevy_ecs/src/schedule/executor_parallel.rs b/crates/bevy_ecs/src/schedule/executor_parallel.rs index b2ab15c7a528c..9262b232dff99 100644 --- a/crates/bevy_ecs/src/schedule/executor_parallel.rs +++ b/crates/bevy_ecs/src/schedule/executor_parallel.rs @@ -179,7 +179,7 @@ impl ParallelExecutor { /// queues systems with no dependencies to run (or skip) at next opportunity. fn prepare_systems<'scope>( &mut self, - scope: &mut Scope<'scope, ()>, + scope: &Scope<'scope, ()>, systems: &'scope mut [ParallelSystemContainer], world: &'scope World, ) { diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs index 8f248e1005c38..f55f3119801f3 100644 --- a/crates/bevy_tasks/src/single_threaded_task_pool.rs +++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs @@ -68,7 +68,7 @@ impl TaskPool { let mut scope = Scope { executor, - results: Vec::new(), + results: Arc::new(Mutex::new(Vec::new())), }; f(&mut scope); @@ -76,11 +76,14 @@ impl TaskPool { // Loop until all tasks are done while executor.try_tick() {} - scope + let result = scope .results + .lock() + .unwrap() .iter() .map(|result| result.lock().unwrap().take().unwrap()) - .collect() + .collect(); + result } // Spawns a static future onto the JS event loop. For now it is returning FakeTask @@ -122,17 +125,17 @@ impl FakeTask { pub struct Scope<'scope, T> { executor: &'scope async_executor::LocalExecutor<'scope>, // Vector to gather results of all futures spawned during scope run - results: Vec>>>, + results: Arc>>>>>, } impl<'scope, T: Send + 'scope> Scope<'scope, T> { - pub fn spawn + 'scope + Send>(&mut self, f: Fut) { + pub fn spawn + 'scope + Send>(&self, f: Fut) { self.spawn_local(f); } - pub fn spawn_local + 'scope>(&mut self, f: Fut) { + pub fn spawn_local + 'scope>(&self, f: Fut) { let result = Arc::new(Mutex::new(None)); - self.results.push(result.clone()); + self.results.lock().unwrap().push(result.clone()); let f = async move { result.lock().unwrap().replace(f.await); }; diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index 597ebc334c872..5ab71be6224da 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -2,7 +2,7 @@ use std::{ future::Future, mem, pin::Pin, - sync::Arc, + sync::{Arc, Mutex}, thread::{self, JoinHandle}, }; @@ -164,7 +164,7 @@ impl TaskPool { /// This is similar to `rayon::scope` and `crossbeam::scope` pub fn scope<'scope, F, T>(&self, f: F) -> Vec where - F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send, + F: FnOnce(&Scope<'scope, T>) + 'scope + Send, T: Send + 'static, { TaskPool::LOCAL_EXECUTOR.with(|local_executor| { @@ -179,19 +179,20 @@ impl TaskPool { let mut scope = Scope { executor, local_executor, - spawned: Vec::new(), + spawned: Arc::new(Mutex::new(Vec::new())), }; f(&mut scope); - if scope.spawned.is_empty() { + let mut spawned = scope.spawned.lock().unwrap(); + if spawned.is_empty() { Vec::default() - } else if scope.spawned.len() == 1 { - vec![future::block_on(&mut scope.spawned[0])] + } else if spawned.len() == 1 { + vec![future::block_on(&mut spawned[0])] } else { let fut = async move { - let mut results = Vec::with_capacity(scope.spawned.len()); - for task in scope.spawned { + let mut results = Vec::with_capacity(spawned.len()); + for task in spawned.iter_mut() { results.push(task.await); } @@ -265,7 +266,7 @@ impl Default for TaskPool { pub struct Scope<'scope, T> { executor: &'scope async_executor::Executor<'scope>, local_executor: &'scope async_executor::LocalExecutor<'scope>, - spawned: Vec>, + spawned: Arc>>>, } impl<'scope, T: Send + 'scope> Scope<'scope, T> { @@ -277,9 +278,9 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> { /// instead. /// /// For more information, see [`TaskPool::scope`]. - pub fn spawn + 'scope + Send>(&mut self, f: Fut) { + pub fn spawn + 'scope + Send>(&self, f: Fut) { let task = self.executor.spawn(f); - self.spawned.push(task); + self.spawned.lock().unwrap().push(task); } /// Spawns a scoped future onto the thread-local executor. The scope *must* outlive @@ -288,9 +289,9 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> { /// [`Scope::spawn`] instead, unless the provided future is not `Send`. /// /// For more information, see [`TaskPool::scope`]. - pub fn spawn_local + 'scope>(&mut self, f: Fut) { + pub fn spawn_local + 'scope>(&self, f: Fut) { let task = self.local_executor.spawn(f); - self.spawned.push(task); + self.spawned.lock().unwrap().push(task); } } @@ -334,6 +335,43 @@ mod tests { assert_eq!(count.load(Ordering::Relaxed), 100); } + #[test] + fn test_nested_spawn() { + let pool = TaskPool::new(); + + let foo = Box::new(42); + let foo = &*foo; + + let count = Arc::new(AtomicI32::new(0)); + + let outputs = pool.scope(|scope| { + for _ in 0..10 { + let count_clone = count.clone(); + scope.spawn(async move { + for _ in 0..10 { + let count_clone_clone = count_clone.clone(); + scope.spawn(async move { + if *foo != 42 { + panic!("not 42!?!?") + } else { + count_clone_clone.fetch_add(1, Ordering::Relaxed); + *foo + } + }); + } + *foo + }); + } + }); + + for output in &outputs { + assert_eq!(*output, 42); + } + + assert_eq!(outputs.len(), 100); + assert_eq!(count.load(Ordering::Relaxed), 100); + } + #[test] fn test_mixed_spawn_local_and_spawn() { let pool = TaskPool::new();