Skip to content

Commit c181071

Browse files
committed
Add support for scoped threads
Add loom::thread::scope to mirror std::thread::scope provided by the standard library.
1 parent ce8a232 commit c181071

File tree

2 files changed

+358
-47
lines changed

2 files changed

+358
-47
lines changed

Diff for: src/thread.rs

+260-47
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@ use std::{fmt, io};
1414
use tracing::trace;
1515

1616
/// Mock implementation of `std::thread::JoinHandle`.
17-
pub struct JoinHandle<T> {
18-
result: Arc<Mutex<Option<std::thread::Result<T>>>>,
19-
notify: rt::Notify,
20-
thread: Thread,
21-
}
17+
pub struct JoinHandle<T>(JoinHandleInner<'static, T>);
2218

2319
/// Mock implementation of `std::thread::Thread`.
2420
#[derive(Clone, Debug)]
@@ -129,7 +125,7 @@ where
129125
F: 'static,
130126
T: 'static,
131127
{
132-
spawn_internal(f, None, None, location!())
128+
JoinHandle(spawn_internal_static(f, None, None, location!()))
133129
}
134130

135131
/// Mock implementation of `std::thread::park`.
@@ -143,43 +139,6 @@ pub fn park() {
143139
rt::park(location!());
144140
}
145141

146-
fn spawn_internal<F, T>(
147-
f: F,
148-
name: Option<String>,
149-
stack_size: Option<usize>,
150-
location: Location,
151-
) -> JoinHandle<T>
152-
where
153-
F: FnOnce() -> T,
154-
F: 'static,
155-
T: 'static,
156-
{
157-
let result = Arc::new(Mutex::new(None));
158-
let notify = rt::Notify::new(true, false);
159-
160-
let id = {
161-
let name = name.clone();
162-
let result = result.clone();
163-
rt::spawn(stack_size, move || {
164-
rt::execution(|execution| {
165-
init_current(execution, name);
166-
});
167-
168-
*result.lock().unwrap() = Some(Ok(f()));
169-
notify.notify(location);
170-
})
171-
};
172-
173-
JoinHandle {
174-
result,
175-
notify,
176-
thread: Thread {
177-
id: ThreadId { id },
178-
name,
179-
},
180-
}
181-
}
182-
183142
impl Builder {
184143
/// Generates the base configuration for spawning a thread, from which
185144
/// configuration methods can be chained.
@@ -217,21 +176,53 @@ impl Builder {
217176
F: Send + 'static,
218177
T: Send + 'static,
219178
{
220-
Ok(spawn_internal(f, self.name, self.stack_size, location!()))
179+
Ok(JoinHandle(spawn_internal_static(
180+
f,
181+
self.name,
182+
self.stack_size,
183+
location!(),
184+
)))
185+
}
186+
}
187+
188+
impl Builder {
189+
/// Spawns a new scoped thread using the settings set through this `Builder`.
190+
pub fn spawn_scoped<'scope, 'env, F, T>(
191+
self,
192+
scope: &'scope Scope<'scope, 'env>,
193+
f: F,
194+
) -> io::Result<ScopedJoinHandle<'scope, T>>
195+
where
196+
F: FnOnce() -> T + Send + 'scope,
197+
T: Send + 'scope,
198+
{
199+
Ok(ScopedJoinHandle(
200+
// Safety: the call to this function requires a `&'scope Scope`
201+
// which can only be constructed by `scope()`, which ensures that
202+
// all spawned threads are joined before the `Scope` is destroyed.
203+
unsafe {
204+
spawn_internal(
205+
f,
206+
self.name,
207+
self.stack_size,
208+
Some(scope.data.clone()),
209+
location!(),
210+
)
211+
},
212+
))
221213
}
222214
}
223215

224216
impl<T> JoinHandle<T> {
225217
/// Waits for the associated thread to finish.
226218
#[track_caller]
227219
pub fn join(self) -> std::thread::Result<T> {
228-
self.notify.wait(location!());
229-
self.result.lock().unwrap().take().unwrap()
220+
self.0.join()
230221
}
231222

232223
/// Gets a handle to the underlying [`Thread`]
233224
pub fn thread(&self) -> &Thread {
234-
&self.thread
225+
self.0.thread()
235226
}
236227
}
237228

@@ -312,3 +303,225 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
312303
f.pad("LocalKey { .. }")
313304
}
314305
}
306+
307+
/// A scope for spawning scoped threads.
308+
///
309+
/// See [`scope`] for more details.
310+
#[derive(Debug)]
311+
pub struct Scope<'scope, 'env: 'scope> {
312+
data: Arc<ScopeData>,
313+
scope: PhantomData<&'scope mut &'scope ()>,
314+
env: PhantomData<&'env mut &'env ()>,
315+
}
316+
317+
/// An owned permission to join on a scoped thread (block on its termination).
318+
///
319+
/// See [`Scope::spawn`] for details.
320+
#[derive(Debug)]
321+
pub struct ScopedJoinHandle<'scope, T>(JoinHandleInner<'scope, T>);
322+
323+
/// Create a scope for spawning scoped threads.
324+
///
325+
/// Mock implementation of [`std::thread::scope`].
326+
#[track_caller]
327+
pub fn scope<'env, F, T>(f: F) -> T
328+
where
329+
F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
330+
{
331+
let scope = Scope {
332+
data: Arc::new(ScopeData {
333+
running_threads: Mutex::default(),
334+
main_thread: current(),
335+
}),
336+
env: PhantomData,
337+
scope: PhantomData,
338+
};
339+
340+
// Run `f`, but catch panics so we can make sure to wait for all the threads to join.
341+
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&scope)));
342+
343+
// Wait until all the threads are finished. This is required to fulfill
344+
// the safety requirements of `spawn_internal`.
345+
let running = loop {
346+
{
347+
let running = scope.data.running_threads.lock().unwrap();
348+
if running.count == 0 {
349+
break running;
350+
}
351+
}
352+
park();
353+
};
354+
355+
for notify in &running.notify_on_finished {
356+
notify.wait(location!())
357+
}
358+
359+
// Throw any panic from `f`, or the return value of `f` if no thread panicked.
360+
match result {
361+
Err(e) => std::panic::resume_unwind(e),
362+
Ok(result) => result,
363+
}
364+
}
365+
366+
impl<'scope, 'env> Scope<'scope, 'env> {
367+
/// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
368+
///
369+
/// See [`std::thread::Scope`] and [`std::thread::scope`] for details.
370+
pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
371+
where
372+
F: FnOnce() -> T + Send + 'scope,
373+
T: Send + 'scope,
374+
{
375+
Builder::new()
376+
.spawn_scoped(self, f)
377+
.expect("failed to spawn thread")
378+
}
379+
}
380+
381+
impl<'scope, T> ScopedJoinHandle<'scope, T> {
382+
/// Extracts a handle to the underlying thread.
383+
pub fn thread(&self) -> &Thread {
384+
self.0.thread()
385+
}
386+
387+
/// Waits for the associated thread to finish.
388+
pub fn join(self) -> std::thread::Result<T> {
389+
self.0.join()
390+
}
391+
}
392+
393+
/// Handle for joining on a thread with a scope.
394+
#[derive(Debug)]
395+
struct JoinHandleInner<'scope, T> {
396+
data: Arc<ThreadData<'scope, T>>,
397+
notify: rt::Notify,
398+
thread: Thread,
399+
}
400+
401+
/// Spawns a thread without a local scope.
402+
fn spawn_internal_static<F, T>(
403+
f: F,
404+
name: Option<String>,
405+
stack_size: Option<usize>,
406+
location: Location,
407+
) -> JoinHandleInner<'static, T>
408+
where
409+
F: FnOnce() -> T,
410+
F: 'static,
411+
T: 'static,
412+
{
413+
// Safety: the requirements of `spawn_internal` are trivially satisfied
414+
// since there is no `scope`.
415+
unsafe { spawn_internal(f, name, stack_size, None, location) }
416+
}
417+
418+
/// Spawns a thread with an optional scope.
419+
///
420+
/// The caller must ensure that if `scope` is not None, the provided closure
421+
/// finishes before `'scope` ends.
422+
unsafe fn spawn_internal<'scope, F, T>(
423+
f: F,
424+
name: Option<String>,
425+
stack_size: Option<usize>,
426+
scope: Option<Arc<ScopeData>>,
427+
location: Location,
428+
) -> JoinHandleInner<'scope, T>
429+
where
430+
F: FnOnce() -> T,
431+
F: 'scope,
432+
T: 'scope,
433+
{
434+
let scope_notify = scope
435+
.clone()
436+
.map(|scope| (scope.add_running_thread(), scope));
437+
let thread_data = Arc::new(ThreadData::new());
438+
let notify = rt::Notify::new(true, false);
439+
440+
let id = {
441+
let name = name.clone();
442+
let thread_data = thread_data.clone();
443+
let body: Box<dyn FnOnce() + 'scope> = Box::new(move || {
444+
rt::execution(|execution| {
445+
init_current(execution, name);
446+
});
447+
448+
*thread_data.result.lock().unwrap() = Some(Ok(f()));
449+
notify.notify(location);
450+
451+
if let Some((notifier, scope)) = scope_notify {
452+
notifier.notify(location!());
453+
scope.remove_running_thread()
454+
}
455+
});
456+
rt::spawn(
457+
stack_size,
458+
std::mem::transmute::<_, Box<dyn FnOnce()>>(body),
459+
)
460+
};
461+
462+
JoinHandleInner {
463+
data: thread_data,
464+
notify,
465+
thread: Thread {
466+
id: ThreadId { id },
467+
name,
468+
},
469+
}
470+
}
471+
472+
/// Data for a running thread.
473+
#[derive(Debug)]
474+
struct ThreadData<'scope, T> {
475+
result: Mutex<Option<std::thread::Result<T>>>,
476+
_marker: PhantomData<Option<&'scope ScopeData>>,
477+
}
478+
479+
impl<'scope, T> ThreadData<'scope, T> {
480+
fn new() -> Self {
481+
Self {
482+
result: Mutex::new(None),
483+
_marker: PhantomData,
484+
}
485+
}
486+
}
487+
488+
impl<'scope, T> JoinHandleInner<'scope, T> {
489+
fn join(self) -> std::thread::Result<T> {
490+
self.notify.wait(location!());
491+
self.data.result.lock().unwrap().take().unwrap()
492+
}
493+
494+
fn thread(&self) -> &Thread {
495+
&self.thread
496+
}
497+
}
498+
499+
#[derive(Default, Debug)]
500+
struct ScopeThreads {
501+
count: usize,
502+
notify_on_finished: Vec<rt::Notify>,
503+
}
504+
505+
#[derive(Debug)]
506+
struct ScopeData {
507+
running_threads: Mutex<ScopeThreads>,
508+
main_thread: Thread,
509+
}
510+
511+
impl ScopeData {
512+
fn add_running_thread(&self) -> rt::Notify {
513+
let mut running = self.running_threads.lock().unwrap();
514+
running.count += 1;
515+
let notify = rt::Notify::new(true, false);
516+
running.notify_on_finished.push(notify);
517+
notify
518+
}
519+
520+
fn remove_running_thread(&self) {
521+
let mut running = self.running_threads.lock().unwrap();
522+
running.count -= 1;
523+
if running.count == 0 {
524+
self.main_thread.unpark()
525+
}
526+
}
527+
}

0 commit comments

Comments
 (0)