Skip to content

Commit ebafbae

Browse files
Zoxcammaraskar
andcommitted
Add a WorkerLocal type which allow you to hold a value per Rayon worker thread
Co-authored-by: Ammar Askar <[email protected]>
1 parent 5b167be commit ebafbae

File tree

3 files changed

+82
-2
lines changed

3 files changed

+82
-2
lines changed

rayon-core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ mod sleep;
8989
mod spawn;
9090
mod thread_pool;
9191
mod unwind;
92+
mod worker_local;
9293

9394
mod compile_fail;
9495
mod test;
@@ -105,6 +106,7 @@ pub use self::thread_pool::current_thread_has_pending_tasks;
105106
pub use self::thread_pool::current_thread_index;
106107
pub use self::thread_pool::ThreadPool;
107108
pub use self::thread_pool::{yield_local, yield_now, Yield};
109+
pub use worker_local::WorkerLocal;
108110

109111
use self::registry::{CustomSpawn, DefaultSpawn, ThreadSpawn};
110112

rayon-core/src/registry.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,12 +676,12 @@ pub(super) struct WorkerThread {
676676
/// local queue used for `spawn_fifo` indirection
677677
fifo: JobFifo,
678678

679-
index: usize,
679+
pub(crate) index: usize,
680680

681681
/// A weak random number generator.
682682
rng: XorShift64Star,
683683

684-
registry: Arc<Registry>,
684+
pub(crate) registry: Arc<Registry>,
685685
}
686686

687687
// This is a bit sketchy, but basically: the WorkerThread is

rayon-core/src/worker_local.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
use crate::registry::{Registry, WorkerThread};
2+
use std::fmt;
3+
use std::ops::Deref;
4+
use std::sync::Arc;
5+
6+
#[repr(align(64))]
7+
#[derive(Debug)]
8+
struct CacheAligned<T>(T);
9+
10+
/// Holds worker-locals values for each thread in a thread pool.
11+
/// You can only access the worker local value through the Deref impl
12+
/// on the thread pool it was constructed on. It will panic otherwise
13+
pub struct WorkerLocal<T> {
14+
locals: Vec<CacheAligned<T>>,
15+
registry: Arc<Registry>,
16+
}
17+
18+
/// We prevent concurrent access to the underlying value in the
19+
/// Deref impl, thus any values safe to send across threads can
20+
/// be used with WorkerLocal.
21+
unsafe impl<T: Send> Sync for WorkerLocal<T> {}
22+
23+
impl<T> WorkerLocal<T> {
24+
/// Creates a new worker local where the `initial` closure computes the
25+
/// value this worker local should take for each thread in the thread pool.
26+
#[inline]
27+
pub fn new<F: FnMut(usize) -> T>(mut initial: F) -> WorkerLocal<T> {
28+
let registry = Registry::current();
29+
WorkerLocal {
30+
locals: (0..registry.num_threads())
31+
.map(|i| CacheAligned(initial(i)))
32+
.collect(),
33+
registry,
34+
}
35+
}
36+
37+
/// Returns the worker-local value for each thread
38+
#[inline]
39+
pub fn into_inner(self) -> Vec<T> {
40+
self.locals.into_iter().map(|c| c.0).collect()
41+
}
42+
43+
fn current(&self) -> &T {
44+
unsafe {
45+
let worker_thread = WorkerThread::current();
46+
if worker_thread.is_null()
47+
|| &*(*worker_thread).registry as *const _ != &*self.registry as *const _
48+
{
49+
panic!("WorkerLocal can only be used on the thread pool it was created on")
50+
}
51+
&self.locals[(*worker_thread).index].0
52+
}
53+
}
54+
}
55+
56+
impl<T> WorkerLocal<Vec<T>> {
57+
/// Joins the elements of all the worker locals into one Vec
58+
pub fn join(self) -> Vec<T> {
59+
self.into_inner().into_iter().flat_map(|v| v).collect()
60+
}
61+
}
62+
63+
impl<T: fmt::Debug> fmt::Debug for WorkerLocal<T> {
64+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65+
f.debug_struct("WorkerLocal")
66+
.field("registry", &self.registry.id())
67+
.finish()
68+
}
69+
}
70+
71+
impl<T> Deref for WorkerLocal<T> {
72+
type Target = T;
73+
74+
#[inline(always)]
75+
fn deref(&self) -> &T {
76+
self.current()
77+
}
78+
}

0 commit comments

Comments
 (0)