@@ -14,11 +14,7 @@ use std::{fmt, io};
14
14
use tracing:: trace;
15
15
16
16
/// Mock implementation of `std::thread::JoinHandle`.
17
- pub struct JoinHandle < T > {
18
- result : Arc < Mutex < Option < std:: thread:: Result < T > > > > ,
19
- notify : rt:: Notify ,
20
- thread : Thread ,
21
- }
17
+ pub struct JoinHandle < T > ( JoinHandleInner < ' static , T > ) ;
22
18
23
19
/// Mock implementation of `std::thread::Thread`.
24
20
#[ derive( Clone , Debug ) ]
@@ -129,7 +125,7 @@ where
129
125
F : ' static ,
130
126
T : ' static ,
131
127
{
132
- spawn_internal ( f, None , None , location ! ( ) )
128
+ JoinHandle ( spawn_internal_static ( f, None , None , location ! ( ) ) )
133
129
}
134
130
135
131
/// Mock implementation of `std::thread::park`.
@@ -143,43 +139,6 @@ pub fn park() {
143
139
rt:: park ( location ! ( ) ) ;
144
140
}
145
141
146
- fn spawn_internal < F , T > (
147
- f : F ,
148
- name : Option < String > ,
149
- stack_size : Option < usize > ,
150
- location : Location ,
151
- ) -> JoinHandle < T >
152
- where
153
- F : FnOnce ( ) -> T ,
154
- F : ' static ,
155
- T : ' static ,
156
- {
157
- let result = Arc :: new ( Mutex :: new ( None ) ) ;
158
- let notify = rt:: Notify :: new ( true , false ) ;
159
-
160
- let id = {
161
- let name = name. clone ( ) ;
162
- let result = result. clone ( ) ;
163
- rt:: spawn ( stack_size, move || {
164
- rt:: execution ( |execution| {
165
- init_current ( execution, name) ;
166
- } ) ;
167
-
168
- * result. lock ( ) . unwrap ( ) = Some ( Ok ( f ( ) ) ) ;
169
- notify. notify ( location) ;
170
- } )
171
- } ;
172
-
173
- JoinHandle {
174
- result,
175
- notify,
176
- thread : Thread {
177
- id : ThreadId { id } ,
178
- name,
179
- } ,
180
- }
181
- }
182
-
183
142
impl Builder {
184
143
/// Generates the base configuration for spawning a thread, from which
185
144
/// configuration methods can be chained.
@@ -217,21 +176,53 @@ impl Builder {
217
176
F : Send + ' static ,
218
177
T : Send + ' static ,
219
178
{
220
- Ok ( spawn_internal ( f, self . name , self . stack_size , location ! ( ) ) )
179
+ Ok ( JoinHandle ( spawn_internal_static (
180
+ f,
181
+ self . name ,
182
+ self . stack_size ,
183
+ location ! ( ) ,
184
+ ) ) )
185
+ }
186
+ }
187
+
188
+ impl Builder {
189
+ /// Spawns a new scoped thread using the settings set through this `Builder`.
190
+ pub fn spawn_scoped < ' scope , ' env , F , T > (
191
+ self ,
192
+ scope : & ' scope Scope < ' scope , ' env > ,
193
+ f : F ,
194
+ ) -> io:: Result < ScopedJoinHandle < ' scope , T > >
195
+ where
196
+ F : FnOnce ( ) -> T + Send + ' scope ,
197
+ T : Send + ' scope ,
198
+ {
199
+ Ok ( ScopedJoinHandle (
200
+ // Safety: the call to this function requires a `&'scope Scope`
201
+ // which can only be constructed by `scope()`, which ensures that
202
+ // all spawned threads are joined before the `Scope` is destroyed.
203
+ unsafe {
204
+ spawn_internal (
205
+ f,
206
+ self . name ,
207
+ self . stack_size ,
208
+ Some ( scope. data . clone ( ) ) ,
209
+ location ! ( ) ,
210
+ )
211
+ } ,
212
+ ) )
221
213
}
222
214
}
223
215
224
216
impl < T > JoinHandle < T > {
225
217
/// Waits for the associated thread to finish.
226
218
#[ track_caller]
227
219
pub fn join ( self ) -> std:: thread:: Result < T > {
228
- self . notify . wait ( location ! ( ) ) ;
229
- self . result . lock ( ) . unwrap ( ) . take ( ) . unwrap ( )
220
+ self . 0 . join ( )
230
221
}
231
222
232
223
/// Gets a handle to the underlying [`Thread`]
233
224
pub fn thread ( & self ) -> & Thread {
234
- & self . thread
225
+ self . 0 . thread ( )
235
226
}
236
227
}
237
228
@@ -312,3 +303,225 @@ impl<T: 'static> fmt::Debug for LocalKey<T> {
312
303
f. pad ( "LocalKey { .. }" )
313
304
}
314
305
}
306
+
307
+ /// A scope for spawning scoped threads.
308
+ ///
309
+ /// See [`scope`] for more details.
310
+ #[ derive( Debug ) ]
311
+ pub struct Scope < ' scope , ' env : ' scope > {
312
+ data : Arc < ScopeData > ,
313
+ scope : PhantomData < & ' scope mut & ' scope ( ) > ,
314
+ env : PhantomData < & ' env mut & ' env ( ) > ,
315
+ }
316
+
317
+ /// An owned permission to join on a scoped thread (block on its termination).
318
+ ///
319
+ /// See [`Scope::spawn`] for details.
320
+ #[ derive( Debug ) ]
321
+ pub struct ScopedJoinHandle < ' scope , T > ( JoinHandleInner < ' scope , T > ) ;
322
+
323
+ /// Create a scope for spawning scoped threads.
324
+ ///
325
+ /// Mock implementation of [`std::thread::scope`].
326
+ #[ track_caller]
327
+ pub fn scope < ' env , F , T > ( f : F ) -> T
328
+ where
329
+ F : for < ' scope > FnOnce ( & ' scope Scope < ' scope , ' env > ) -> T ,
330
+ {
331
+ let scope = Scope {
332
+ data : Arc :: new ( ScopeData {
333
+ running_threads : Mutex :: default ( ) ,
334
+ main_thread : current ( ) ,
335
+ } ) ,
336
+ env : PhantomData ,
337
+ scope : PhantomData ,
338
+ } ;
339
+
340
+ // Run `f`, but catch panics so we can make sure to wait for all the threads to join.
341
+ let result = std:: panic:: catch_unwind ( std:: panic:: AssertUnwindSafe ( || f ( & scope) ) ) ;
342
+
343
+ // Wait until all the threads are finished. This is required to fulfill
344
+ // the safety requirements of `spawn_internal`.
345
+ let running = loop {
346
+ {
347
+ let running = scope. data . running_threads . lock ( ) . unwrap ( ) ;
348
+ if running. count == 0 {
349
+ break running;
350
+ }
351
+ }
352
+ park ( ) ;
353
+ } ;
354
+
355
+ for notify in & running. notify_on_finished {
356
+ notify. wait ( location ! ( ) )
357
+ }
358
+
359
+ // Throw any panic from `f`, or the return value of `f` if no thread panicked.
360
+ match result {
361
+ Err ( e) => std:: panic:: resume_unwind ( e) ,
362
+ Ok ( result) => result,
363
+ }
364
+ }
365
+
366
+ impl < ' scope , ' env > Scope < ' scope , ' env > {
367
+ /// Spawns a new thread within a scope, returning a [`ScopedJoinHandle`] for it.
368
+ ///
369
+ /// See [`std::thread::Scope`] and [`std::thread::scope`] for details.
370
+ pub fn spawn < F , T > ( & ' scope self , f : F ) -> ScopedJoinHandle < ' scope , T >
371
+ where
372
+ F : FnOnce ( ) -> T + Send + ' scope ,
373
+ T : Send + ' scope ,
374
+ {
375
+ Builder :: new ( )
376
+ . spawn_scoped ( self , f)
377
+ . expect ( "failed to spawn thread" )
378
+ }
379
+ }
380
+
381
+ impl < ' scope , T > ScopedJoinHandle < ' scope , T > {
382
+ /// Extracts a handle to the underlying thread.
383
+ pub fn thread ( & self ) -> & Thread {
384
+ self . 0 . thread ( )
385
+ }
386
+
387
+ /// Waits for the associated thread to finish.
388
+ pub fn join ( self ) -> std:: thread:: Result < T > {
389
+ self . 0 . join ( )
390
+ }
391
+ }
392
+
393
+ /// Handle for joining on a thread with a scope.
394
+ #[ derive( Debug ) ]
395
+ struct JoinHandleInner < ' scope , T > {
396
+ data : Arc < ThreadData < ' scope , T > > ,
397
+ notify : rt:: Notify ,
398
+ thread : Thread ,
399
+ }
400
+
401
+ /// Spawns a thread without a local scope.
402
+ fn spawn_internal_static < F , T > (
403
+ f : F ,
404
+ name : Option < String > ,
405
+ stack_size : Option < usize > ,
406
+ location : Location ,
407
+ ) -> JoinHandleInner < ' static , T >
408
+ where
409
+ F : FnOnce ( ) -> T ,
410
+ F : ' static ,
411
+ T : ' static ,
412
+ {
413
+ // Safety: the requirements of `spawn_internal` are trivially satisfied
414
+ // since there is no `scope`.
415
+ unsafe { spawn_internal ( f, name, stack_size, None , location) }
416
+ }
417
+
418
+ /// Spawns a thread with an optional scope.
419
+ ///
420
+ /// The caller must ensure that if `scope` is not None, the provided closure
421
+ /// finishes before `'scope` ends.
422
+ unsafe fn spawn_internal < ' scope , F , T > (
423
+ f : F ,
424
+ name : Option < String > ,
425
+ stack_size : Option < usize > ,
426
+ scope : Option < Arc < ScopeData > > ,
427
+ location : Location ,
428
+ ) -> JoinHandleInner < ' scope , T >
429
+ where
430
+ F : FnOnce ( ) -> T ,
431
+ F : ' scope ,
432
+ T : ' scope ,
433
+ {
434
+ let scope_notify = scope
435
+ . clone ( )
436
+ . map ( |scope| ( scope. add_running_thread ( ) , scope) ) ;
437
+ let thread_data = Arc :: new ( ThreadData :: new ( ) ) ;
438
+ let notify = rt:: Notify :: new ( true , false ) ;
439
+
440
+ let id = {
441
+ let name = name. clone ( ) ;
442
+ let thread_data = thread_data. clone ( ) ;
443
+ let body: Box < dyn FnOnce ( ) + ' scope > = Box :: new ( move || {
444
+ rt:: execution ( |execution| {
445
+ init_current ( execution, name) ;
446
+ } ) ;
447
+
448
+ * thread_data. result . lock ( ) . unwrap ( ) = Some ( Ok ( f ( ) ) ) ;
449
+ notify. notify ( location) ;
450
+
451
+ if let Some ( ( notifier, scope) ) = scope_notify {
452
+ notifier. notify ( location ! ( ) ) ;
453
+ scope. remove_running_thread ( )
454
+ }
455
+ } ) ;
456
+ rt:: spawn (
457
+ stack_size,
458
+ std:: mem:: transmute :: < _ , Box < dyn FnOnce ( ) > > ( body) ,
459
+ )
460
+ } ;
461
+
462
+ JoinHandleInner {
463
+ data : thread_data,
464
+ notify,
465
+ thread : Thread {
466
+ id : ThreadId { id } ,
467
+ name,
468
+ } ,
469
+ }
470
+ }
471
+
472
+ /// Data for a running thread.
473
+ #[ derive( Debug ) ]
474
+ struct ThreadData < ' scope , T > {
475
+ result : Mutex < Option < std:: thread:: Result < T > > > ,
476
+ _marker : PhantomData < Option < & ' scope ScopeData > > ,
477
+ }
478
+
479
+ impl < ' scope , T > ThreadData < ' scope , T > {
480
+ fn new ( ) -> Self {
481
+ Self {
482
+ result : Mutex :: new ( None ) ,
483
+ _marker : PhantomData ,
484
+ }
485
+ }
486
+ }
487
+
488
+ impl < ' scope , T > JoinHandleInner < ' scope , T > {
489
+ fn join ( self ) -> std:: thread:: Result < T > {
490
+ self . notify . wait ( location ! ( ) ) ;
491
+ self . data . result . lock ( ) . unwrap ( ) . take ( ) . unwrap ( )
492
+ }
493
+
494
+ fn thread ( & self ) -> & Thread {
495
+ & self . thread
496
+ }
497
+ }
498
+
499
+ #[ derive( Default , Debug ) ]
500
+ struct ScopeThreads {
501
+ count : usize ,
502
+ notify_on_finished : Vec < rt:: Notify > ,
503
+ }
504
+
505
+ #[ derive( Debug ) ]
506
+ struct ScopeData {
507
+ running_threads : Mutex < ScopeThreads > ,
508
+ main_thread : Thread ,
509
+ }
510
+
511
+ impl ScopeData {
512
+ fn add_running_thread ( & self ) -> rt:: Notify {
513
+ let mut running = self . running_threads . lock ( ) . unwrap ( ) ;
514
+ running. count += 1 ;
515
+ let notify = rt:: Notify :: new ( true , false ) ;
516
+ running. notify_on_finished . push ( notify) ;
517
+ notify
518
+ }
519
+
520
+ fn remove_running_thread ( & self ) {
521
+ let mut running = self . running_threads . lock ( ) . unwrap ( ) ;
522
+ running. count -= 1 ;
523
+ if running. count == 0 {
524
+ self . main_thread . unpark ( )
525
+ }
526
+ }
527
+ }
0 commit comments