Skip to content

Commit 36e5d84

Browse files
committed
Allow to temporarily set the current registry even if it is not associated with a worker thread
1 parent b3bd4bc commit 36e5d84

File tree

2 files changed

+74
-7
lines changed

2 files changed

+74
-7
lines changed

rayon-core/src/registry.rs

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,36 @@ fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
217217
result
218218
}
219219

220+
// This is used to temporarily overwrite the current registry.
221+
//
222+
// This either null, a pointer to the global registry if it was
223+
// ever used to access the global registry or a pointer to a
224+
// registry which is temporarily made current because the current
225+
// thread is not a worker thread but is running a scope associated
226+
// to a specific thread pool.
227+
thread_local! {
228+
static CURRENT_REGISTRY: Cell<*const Arc<Registry>> = const { Cell::new(ptr::null()) };
229+
}
230+
231+
#[cold]
232+
fn set_current_registry_to_global_registry() -> *const Arc<Registry> {
233+
let global = global_registry();
234+
235+
CURRENT_REGISTRY.with(|current_registry| current_registry.set(global));
236+
237+
global
238+
}
239+
240+
fn current_registry() -> *const Arc<Registry> {
241+
let mut current = CURRENT_REGISTRY.with(Cell::get);
242+
243+
if current.is_null() {
244+
current = set_current_registry_to_global_registry();
245+
}
246+
247+
current
248+
}
249+
220250
struct Terminator<'a>(&'a Arc<Registry>);
221251

222252
impl<'a> Drop for Terminator<'a> {
@@ -315,14 +345,47 @@ impl Registry {
315345
unsafe {
316346
let worker_thread = WorkerThread::current();
317347
let registry = if worker_thread.is_null() {
318-
global_registry()
348+
&*current_registry()
319349
} else {
320350
&(*worker_thread).registry
321351
};
322352
Arc::clone(registry)
323353
}
324354
}
325355

356+
/// Optionally install a specific registry as the current one.
357+
///
358+
/// This is used when a thread which is not a worker executes
359+
/// a scope which should use the specific thread pool instead of
360+
/// the global one.
361+
pub(super) fn with_current<F, R>(registry: Option<&Arc<Registry>>, f: F) -> R
362+
where
363+
F: FnOnce() -> R,
364+
{
365+
struct Guard {
366+
current: *const Arc<Registry>,
367+
}
368+
369+
impl Guard {
370+
fn new(registry: &Arc<Registry>) -> Self {
371+
let current =
372+
CURRENT_REGISTRY.with(|current_registry| current_registry.replace(registry));
373+
374+
Self { current }
375+
}
376+
}
377+
378+
impl Drop for Guard {
379+
fn drop(&mut self) {
380+
CURRENT_REGISTRY.with(|current_registry| current_registry.set(self.current));
381+
}
382+
}
383+
384+
let _guard = registry.map(Guard::new);
385+
386+
f()
387+
}
388+
326389
/// Returns the number of threads in the current registry. This
327390
/// is better than `Registry::current().num_threads()` because it
328391
/// avoids incrementing the `Arc`.

rayon-core/src/scope/mod.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -416,9 +416,11 @@ pub(crate) fn do_in_place_scope<'scope, OP, R>(registry: Option<&Arc<Registry>>,
416416
where
417417
OP: FnOnce(&Scope<'scope>) -> R,
418418
{
419-
let thread = unsafe { WorkerThread::current().as_ref() };
420-
let scope = Scope::<'scope>::new(thread, registry);
421-
scope.base.complete(thread, || op(&scope))
419+
Registry::with_current(registry, || {
420+
let thread = unsafe { WorkerThread::current().as_ref() };
421+
let scope = Scope::<'scope>::new(thread, registry);
422+
scope.base.complete(thread, || op(&scope))
423+
})
422424
}
423425

424426
/// Creates a "fork-join" scope `s` with FIFO order, and invokes the
@@ -453,9 +455,11 @@ pub(crate) fn do_in_place_scope_fifo<'scope, OP, R>(registry: Option<&Arc<Regist
453455
where
454456
OP: FnOnce(&ScopeFifo<'scope>) -> R,
455457
{
456-
let thread = unsafe { WorkerThread::current().as_ref() };
457-
let scope = ScopeFifo::<'scope>::new(thread, registry);
458-
scope.base.complete(thread, || op(&scope))
458+
Registry::with_current(registry, || {
459+
let thread = unsafe { WorkerThread::current().as_ref() };
460+
let scope = ScopeFifo::<'scope>::new(thread, registry);
461+
scope.base.complete(thread, || op(&scope))
462+
})
459463
}
460464

461465
impl<'scope> Scope<'scope> {

0 commit comments

Comments
 (0)