@@ -161,7 +161,7 @@ static THE_REGISTRY_SET: Once = Once::new();
161
161
/// Starts the worker threads (if that has not already happened). If
162
162
/// initialization has not already occurred, use the default
163
163
/// configuration.
164
- pub ( super ) fn global_registry ( ) -> & ' static Arc < Registry > {
164
+ fn global_registry ( ) -> & ' static Arc < Registry > {
165
165
set_global_registry ( default_global_registry)
166
166
. or_else ( |err| unsafe { THE_REGISTRY . as_ref ( ) . ok_or ( err) } )
167
167
. expect ( "The global thread pool has not been initialized." )
@@ -217,6 +217,36 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
217
217
result
218
218
}
219
219
220
+ // This is used to temporarily overwrite the current registry.
221
+ //
222
+ // This either null, a pointer to the global registry if it was
223
+ // ever used to access the global registry or a pointer to a
224
+ // registry which is temporarily made current because the current
225
+ // thread is not a worker thread but is running a scope associated
226
+ // to a specific thread pool.
227
+ thread_local ! {
228
+ static CURRENT_REGISTRY : Cell <* const Arc <Registry >> = const { Cell :: new( ptr:: null( ) ) } ;
229
+ }
230
+
231
+ #[ cold]
232
+ fn set_current_registry_to_global_registry ( ) -> * const Arc < Registry > {
233
+ let global = global_registry ( ) ;
234
+
235
+ CURRENT_REGISTRY . with ( |current_registry| current_registry. set ( global) ) ;
236
+
237
+ global
238
+ }
239
+
240
+ pub ( super ) fn current_registry ( ) -> * const Arc < Registry > {
241
+ let mut current = CURRENT_REGISTRY . with ( Cell :: get) ;
242
+
243
+ if current. is_null ( ) {
244
+ current = set_current_registry_to_global_registry ( ) ;
245
+ }
246
+
247
+ current
248
+ }
249
+
220
250
struct Terminator < ' a > ( & ' a Arc < Registry > ) ;
221
251
222
252
impl < ' a > Drop for Terminator < ' a > {
@@ -315,22 +345,55 @@ impl Registry {
315
345
unsafe {
316
346
let worker_thread = WorkerThread :: current ( ) ;
317
347
let registry = if worker_thread. is_null ( ) {
318
- global_registry ( )
348
+ & * current_registry ( )
319
349
} else {
320
350
& ( * worker_thread) . registry
321
351
} ;
322
352
Arc :: clone ( registry)
323
353
}
324
354
}
325
355
356
+ /// Optionally install a specific registry as the current one.
357
+ ///
358
+ /// This is used when a thread which is not a worker executes
359
+ /// a scope which should use the specific thread pool instead of
360
+ /// the global one.
361
+ pub ( super ) fn with_current < F , R > ( registry : Option < & Arc < Registry > > , f : F ) -> R
362
+ where
363
+ F : FnOnce ( ) -> R ,
364
+ {
365
+ struct Guard {
366
+ current : * const Arc < Registry > ,
367
+ }
368
+
369
+ impl Guard {
370
+ fn new ( registry : & Arc < Registry > ) -> Self {
371
+ let current =
372
+ CURRENT_REGISTRY . with ( |current_registry| current_registry. replace ( registry) ) ;
373
+
374
+ Self { current }
375
+ }
376
+ }
377
+
378
+ impl Drop for Guard {
379
+ fn drop ( & mut self ) {
380
+ CURRENT_REGISTRY . with ( |current_registry| current_registry. set ( self . current ) ) ;
381
+ }
382
+ }
383
+
384
+ let _guard = registry. map ( Guard :: new) ;
385
+
386
+ f ( )
387
+ }
388
+
326
389
/// Returns the number of threads in the current registry. This
327
390
/// is better than `Registry::current().num_threads()` because it
328
391
/// avoids incrementing the `Arc`.
329
392
pub ( super ) fn current_num_threads ( ) -> usize {
330
393
unsafe {
331
394
let worker_thread = WorkerThread :: current ( ) ;
332
395
if worker_thread. is_null ( ) {
333
- global_registry ( ) . num_threads ( )
396
+ ( * current_registry ( ) ) . num_threads ( )
334
397
} else {
335
398
( * worker_thread) . registry . num_threads ( )
336
399
}
@@ -946,7 +1009,7 @@ where
946
1009
// invalidated until we return.
947
1010
op ( & * owner_thread, false )
948
1011
} else {
949
- global_registry ( ) . in_worker ( op)
1012
+ ( * current_registry ( ) ) . in_worker ( op)
950
1013
}
951
1014
}
952
1015
}
0 commit comments