@@ -8,33 +8,45 @@ cond: std.Thread.Condition = .{},
8
8
run_queue : RunQueue = .{},
9
9
is_running : bool = true ,
10
10
allocator : std.mem.Allocator ,
11
- threads : []std.Thread ,
11
+ threads : if (builtin .single_threaded ) [0 ]std .Thread else []std.Thread ,
12
+ ids : if (builtin .single_threaded ) struct {
13
+ inline fn deinit (_ : @This (), _ : std .mem .Allocator ) void {}
14
+ fn getIndex (_ : @This (), _ : std .Thread .Id ) usize {
15
+ return 0 ;
16
+ }
17
+ } else std .AutoArrayHashMapUnmanaged (std.Thread.Id , void ),
12
18
13
19
const RunQueue = std .SinglyLinkedList (Runnable );
14
20
const Runnable = struct {
15
21
runFn : RunProto ,
16
22
};
17
23
18
- const RunProto = * const fn (* Runnable ) void ;
24
+ const RunProto = * const fn (* Runnable , id : ? usize ) void ;
19
25
20
26
pub const Options = struct {
21
27
allocator : std.mem.Allocator ,
22
- n_jobs : ? u32 = null ,
28
+ n_jobs : ? usize = null ,
29
+ track_ids : bool = false ,
23
30
};
24
31
25
32
pub fn init (pool : * Pool , options : Options ) ! void {
26
33
const allocator = options .allocator ;
27
34
28
35
pool .* = .{
29
36
.allocator = allocator ,
30
- .threads = &[_ ]std.Thread {},
37
+ .threads = if (builtin .single_threaded ) .{} else &.{},
38
+ .ids = .{},
31
39
};
32
40
33
41
if (builtin .single_threaded ) {
34
42
return ;
35
43
}
36
44
37
45
const thread_count = options .n_jobs orelse @max (1 , std .Thread .getCpuCount () catch 1 );
46
+ if (options .track_ids ) {
47
+ try pool .ids .ensureTotalCapacity (allocator , 1 + thread_count );
48
+ pool .ids .putAssumeCapacityNoClobber (std .Thread .getCurrentId (), {});
49
+ }
38
50
39
51
// kill and join any threads we spawned and free memory on error.
40
52
pool .threads = try allocator .alloc (std .Thread , thread_count );
@@ -49,6 +61,7 @@ pub fn init(pool: *Pool, options: Options) !void {
49
61
50
62
pub fn deinit (pool : * Pool ) void {
51
63
pool .join (pool .threads .len ); // kill and join all threads.
64
+ pool .ids .deinit (pool .allocator );
52
65
pool .* = undefined ;
53
66
}
54
67
@@ -96,7 +109,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args
96
109
run_node : RunQueue.Node = .{ .data = .{ .runFn = runFn } },
97
110
wait_group : * WaitGroup ,
98
111
99
- fn runFn (runnable : * Runnable ) void {
112
+ fn runFn (runnable : * Runnable , _ : ? usize ) void {
100
113
const run_node : * RunQueue.Node = @fieldParentPtr ("data" , runnable );
101
114
const closure : * @This () = @alignCast (@fieldParentPtr ("run_node" , run_node ));
102
115
@call (.auto , func , closure .arguments );
@@ -134,6 +147,70 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args
134
147
pool .cond .signal ();
135
148
}
136
149
150
+ /// Runs `func` in the thread pool, calling `WaitGroup.start` beforehand, and
151
+ /// `WaitGroup.finish` after it returns.
152
+ ///
153
+ /// The first argument passed to `func` is a dense `usize` thread id, the rest
154
+ /// of the arguments are passed from `args`. Requires the pool to have been
155
+ /// initialized with `.track_ids = true`.
156
+ ///
157
+ /// In the case that queuing the function call fails to allocate memory, or the
158
+ /// target is single-threaded, the function is called directly.
159
+ pub fn spawnWgId (pool : * Pool , wait_group : * WaitGroup , comptime func : anytype , args : anytype ) void {
160
+ wait_group .start ();
161
+
162
+ if (builtin .single_threaded ) {
163
+ @call (.auto , func , .{0 } ++ args );
164
+ wait_group .finish ();
165
+ return ;
166
+ }
167
+
168
+ const Args = @TypeOf (args );
169
+ const Closure = struct {
170
+ arguments : Args ,
171
+ pool : * Pool ,
172
+ run_node : RunQueue.Node = .{ .data = .{ .runFn = runFn } },
173
+ wait_group : * WaitGroup ,
174
+
175
+ fn runFn (runnable : * Runnable , id : ? usize ) void {
176
+ const run_node : * RunQueue.Node = @fieldParentPtr ("data" , runnable );
177
+ const closure : * @This () = @alignCast (@fieldParentPtr ("run_node" , run_node ));
178
+ @call (.auto , func , .{id .? } ++ closure .arguments );
179
+ closure .wait_group .finish ();
180
+
181
+ // The thread pool's allocator is protected by the mutex.
182
+ const mutex = & closure .pool .mutex ;
183
+ mutex .lock ();
184
+ defer mutex .unlock ();
185
+
186
+ closure .pool .allocator .destroy (closure );
187
+ }
188
+ };
189
+
190
+ {
191
+ pool .mutex .lock ();
192
+
193
+ const closure = pool .allocator .create (Closure ) catch {
194
+ const id : ? usize = pool .ids .getIndex (std .Thread .getCurrentId ());
195
+ pool .mutex .unlock ();
196
+ @call (.auto , func , .{id .? } ++ args );
197
+ wait_group .finish ();
198
+ return ;
199
+ };
200
+ closure .* = .{
201
+ .arguments = args ,
202
+ .pool = pool ,
203
+ .wait_group = wait_group ,
204
+ };
205
+
206
+ pool .run_queue .prepend (& closure .run_node );
207
+ pool .mutex .unlock ();
208
+ }
209
+
210
+ // Notify waiting threads outside the lock to try and keep the critical section small.
211
+ pool .cond .signal ();
212
+ }
213
+
137
214
pub fn spawn (pool : * Pool , comptime func : anytype , args : anytype ) ! void {
138
215
if (builtin .single_threaded ) {
139
216
@call (.auto , func , args );
@@ -181,14 +258,16 @@ fn worker(pool: *Pool) void {
181
258
pool .mutex .lock ();
182
259
defer pool .mutex .unlock ();
183
260
261
+ const id : ? usize = if (pool .ids .count () > 0 ) @intCast (pool .ids .count ()) else null ;
262
+ if (id ) | _ | pool .ids .putAssumeCapacityNoClobber (std .Thread .getCurrentId (), {});
263
+
184
264
while (true ) {
185
265
while (pool .run_queue .popFirst ()) | run_node | {
186
266
// Temporarily unlock the mutex in order to execute the run_node
187
267
pool .mutex .unlock ();
188
268
defer pool .mutex .lock ();
189
269
190
- const runFn = run_node .data .runFn ;
191
- runFn (& run_node .data );
270
+ run_node .data .runFn (& run_node .data , id );
192
271
}
193
272
194
273
// Stop executing instead of waiting if the thread pool is no longer running.
@@ -201,17 +280,23 @@ fn worker(pool: *Pool) void {
201
280
}
202
281
203
282
pub fn waitAndWork (pool : * Pool , wait_group : * WaitGroup ) void {
283
+ var id : ? usize = null ;
284
+
204
285
while (! wait_group .isDone ()) {
205
- if (blk : {
206
- pool .mutex .lock ();
207
- defer pool .mutex .unlock ();
208
- break :blk pool .run_queue .popFirst ();
209
- }) | run_node | {
210
- run_node .data .runFn (& run_node .data );
286
+ pool .mutex .lock ();
287
+ if (pool .run_queue .popFirst ()) | run_node | {
288
+ id = id orelse pool .ids .getIndex (std .Thread .getCurrentId ());
289
+ pool .mutex .unlock ();
290
+ run_node .data .runFn (& run_node .data , id );
211
291
continue ;
212
292
}
213
293
294
+ pool .mutex .unlock ();
214
295
wait_group .wait ();
215
296
return ;
216
297
}
217
298
}
299
+
300
+ pub fn getIdCount (pool : * Pool ) usize {
301
+ return @intCast (1 + pool .threads .len );
302
+ }
0 commit comments