Skip to content

Commit 49958f4

Browse files
committed
Split Luau thread event callback to creation and collection callbacks.
We need this because they have different requirements: - Thread creation callback can return Error or panic - Thread collection callback runs during Luau GC cycle and cannot make any Lua calls or trigger panics.
1 parent 39b3af2 commit 49958f4

File tree

6 files changed

+124
-56
lines changed

6 files changed

+124
-56
lines changed

src/state.rs

+69-34
Original file line numberDiff line numberDiff line change
@@ -696,56 +696,91 @@ impl Lua {
696696
}
697697
}
698698

699-
/// Sets a thread event callback that will be called when a thread is created or destroyed.
699+
/// Sets a thread creation callback that will be called when a thread is created.
700+
#[cfg(any(feature = "luau", doc))]
701+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
702+
pub fn set_thread_creation_callback<F>(&self, callback: F)
703+
where
704+
F: Fn(&Lua, Thread) -> Result<()> + MaybeSend + 'static,
705+
{
706+
let lua = self.lock();
707+
unsafe {
708+
(*lua.extra.get()).thread_creation_callback = Some(XRc::new(callback));
709+
(*ffi::lua_callbacks(lua.main_state())).userthread = Some(Self::userthread_proc);
710+
}
711+
}
712+
713+
/// Sets a thread collection callback that will be called when a thread is destroyed.
700714
///
701-
/// The callback is called with a [`Value`] argument that is either:
702-
/// - A [`Thread`] object when thread is created
703-
/// - A [`LightUserData`] when thread is destroyed
715+
/// Luau GC does not support exceptions during collection, so the callback must be
716+
/// non-panicking. If the callback panics, the program will be aborted.
704717
#[cfg(any(feature = "luau", doc))]
705718
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
706-
pub fn set_thread_event_callback<F>(&self, callback: F)
719+
pub fn set_thread_collection_callback<F>(&self, callback: F)
707720
where
708-
F: Fn(&Lua, Value) -> Result<()> + MaybeSend + 'static,
721+
F: Fn(crate::LightUserData) + MaybeSend + 'static,
709722
{
710-
unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, child: *mut ffi::lua_State) {
711-
let extra = ExtraData::get(child);
712-
let thread_cb = match (*extra).userthread_callback {
723+
let lua = self.lock();
724+
unsafe {
725+
(*lua.extra.get()).thread_collection_callback = Some(XRc::new(callback));
726+
(*ffi::lua_callbacks(lua.main_state())).userthread = Some(Self::userthread_proc);
727+
}
728+
}
729+
730+
#[cfg(feature = "luau")]
731+
unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, child: *mut ffi::lua_State) {
732+
let extra = ExtraData::get(child);
733+
if !parent.is_null() {
734+
// Thread is created
735+
let callback = match (*extra).thread_creation_callback {
713736
Some(ref cb) => cb.clone(),
714737
None => return,
715738
};
716-
if XRc::strong_count(&thread_cb) > 2 {
739+
if XRc::strong_count(&callback) > 2 {
717740
return; // Don't allow recursion
718741
}
719-
let value = match parent.is_null() {
720-
// Thread is about to be destroyed, pass light userdata
721-
true => Value::LightUserData(crate::LightUserData(child as _)),
722-
false => {
723-
// Thread is created, pass thread object
724-
ffi::lua_pushthread(child);
725-
ffi::lua_xmove(child, (*extra).ref_thread, 1);
726-
Value::Thread(Thread((*extra).raw_lua().pop_ref_thread(), child))
727-
}
728-
};
742+
ffi::lua_pushthread(child);
743+
ffi::lua_xmove(child, (*extra).ref_thread, 1);
744+
let value = Thread((*extra).raw_lua().pop_ref_thread(), child);
745+
let _guard = StateGuard::new((*extra).raw_lua(), parent);
729746
callback_error_ext((*extra).raw_lua().state(), extra, false, move |extra, _| {
730-
thread_cb((*extra).lua(), value)
747+
callback((*extra).lua(), value)
731748
})
732-
}
749+
} else {
750+
// Thread is about to be collected
751+
let callback = match (*extra).thread_collection_callback {
752+
Some(ref cb) => cb.clone(),
753+
None => return,
754+
};
733755

734-
// Set thread callback
735-
let lua = self.lock();
736-
unsafe {
737-
(*lua.extra.get()).userthread_callback = Some(XRc::new(callback));
738-
(*ffi::lua_callbacks(lua.main_state())).userthread = Some(userthread_proc);
756+
// We need to wrap the callback call in non-unwind function as it's not safe to unwind when
757+
// Luau GC is running.
758+
// This will trigger `abort()` if the callback panics.
759+
unsafe extern "C" fn run_callback(
760+
callback: *const crate::types::ThreadCollectionCallback,
761+
value: *mut ffi::lua_State,
762+
) {
763+
(*callback)(crate::LightUserData(value as _));
764+
}
765+
766+
(*extra).running_gc = true;
767+
run_callback(&callback, child);
768+
(*extra).running_gc = false;
739769
}
740770
}
741771

742-
/// Removes any thread event callback previously set by `set_thread_event_callback`.
772+
/// Removes any thread creation or collection callbacks previously set by
773+
/// [`Lua::set_thread_creation_callback`] or [`Lua::set_thread_collection_callback`].
774+
///
775+
/// This function has no effect if a thread callbacks were not previously set.
743776
#[cfg(any(feature = "luau", doc))]
744777
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
745-
pub fn remove_thread_event_callback(&self) {
778+
pub fn remove_thread_callbacks(&self) {
746779
let lua = self.lock();
747780
unsafe {
748-
(*lua.extra.get()).userthread_callback = None;
781+
let extra = lua.extra.get();
782+
(*extra).thread_creation_callback = None;
783+
(*extra).thread_collection_callback = None;
749784
(*ffi::lua_callbacks(lua.main_state())).userthread = None;
750785
}
751786
}
@@ -2039,8 +2074,8 @@ impl Lua {
20392074
pub(crate) fn lock(&self) -> ReentrantMutexGuard<RawLua> {
20402075
let rawlua = self.raw.lock();
20412076
#[cfg(feature = "luau")]
2042-
if unsafe { (*rawlua.extra.get()).running_userdata_gc } {
2043-
panic!("Luau VM is suspended while userdata destructor is running");
2077+
if unsafe { (*rawlua.extra.get()).running_gc } {
2078+
panic!("Luau VM is suspended while GC is running");
20442079
}
20452080
rawlua
20462081
}
@@ -2066,8 +2101,8 @@ impl WeakLua {
20662101
pub(crate) fn lock(&self) -> LuaGuard {
20672102
let guard = LuaGuard::new(self.0.upgrade().expect("Lua instance is destroyed"));
20682103
#[cfg(feature = "luau")]
2069-
if unsafe { (*guard.extra.get()).running_userdata_gc } {
2070-
panic!("Luau VM is suspended while userdata destructor is running");
2104+
if unsafe { (*guard.extra.get()).running_gc } {
2105+
panic!("Luau VM is suspended while GC is running");
20712106
}
20722107
guard
20732108
}

src/state/extra.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ pub(crate) struct ExtraData {
8181
#[cfg(feature = "luau")]
8282
pub(super) interrupt_callback: Option<crate::types::InterruptCallback>,
8383
#[cfg(feature = "luau")]
84-
pub(super) userthread_callback: Option<crate::types::UserThreadCallback>,
84+
pub(super) thread_creation_callback: Option<crate::types::ThreadCreationCallback>,
85+
#[cfg(feature = "luau")]
86+
pub(super) thread_collection_callback: Option<crate::types::ThreadCollectionCallback>,
8587

8688
#[cfg(feature = "luau")]
87-
pub(crate) running_userdata_gc: bool,
89+
pub(crate) running_gc: bool,
8890
#[cfg(feature = "luau")]
8991
pub(super) sandboxed: bool,
9092
#[cfg(feature = "luau")]
@@ -181,15 +183,17 @@ impl ExtraData {
181183
#[cfg(feature = "luau")]
182184
interrupt_callback: None,
183185
#[cfg(feature = "luau")]
184-
userthread_callback: None,
186+
thread_creation_callback: None,
187+
#[cfg(feature = "luau")]
188+
thread_collection_callback: None,
185189
#[cfg(feature = "luau")]
186190
sandboxed: false,
187191
#[cfg(feature = "luau")]
188192
compiler: None,
189193
#[cfg(feature = "luau-jit")]
190194
enable_jit: true,
191195
#[cfg(feature = "luau")]
192-
running_userdata_gc: false,
196+
running_gc: false,
193197
}));
194198

195199
// Store it in the registry

src/state/raw.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ impl RawLua {
573573

574574
let protect = !self.unlikely_memory_error();
575575
#[cfg(feature = "luau")]
576-
let protect = protect || (*self.extra.get()).userthread_callback.is_some();
576+
let protect = protect || (*self.extra.get()).thread_creation_callback.is_some();
577577

578578
let thread_state = if !protect {
579579
ffi::lua_newthread(state)

src/types.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,16 @@ pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState> + Send>;
9191
pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState>>;
9292

9393
#[cfg(all(feature = "send", feature = "luau"))]
94-
pub(crate) type UserThreadCallback = XRc<dyn Fn(&Lua, crate::Value) -> Result<()> + Send>;
94+
pub(crate) type ThreadCreationCallback = XRc<dyn Fn(&Lua, crate::Thread) -> Result<()> + Send>;
9595

9696
#[cfg(all(not(feature = "send"), feature = "luau"))]
97-
pub(crate) type UserThreadCallback = XRc<dyn Fn(&Lua, crate::Value) -> Result<()>>;
97+
pub(crate) type ThreadCreationCallback = XRc<dyn Fn(&Lua, crate::Thread) -> Result<()>>;
98+
99+
#[cfg(all(feature = "send", feature = "luau"))]
100+
pub(crate) type ThreadCollectionCallback = XRc<dyn Fn(crate::LightUserData) + Send>;
101+
102+
#[cfg(all(not(feature = "send"), feature = "luau"))]
103+
pub(crate) type ThreadCollectionCallback = XRc<dyn Fn(crate::LightUserData)>;
98104

99105
#[cfg(all(feature = "send", feature = "lua54"))]
100106
pub(crate) type WarnCallback = XRc<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;

src/userdata/util.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -443,11 +443,11 @@ pub(crate) unsafe extern "C" fn collect_userdata<T>(
443443
// Almost none Lua operations are allowed when destructor is running,
444444
// so we need to set a flag to prevent calling any Lua functions
445445
let extra = (*ffi::lua_callbacks(state)).userdata as *mut crate::state::ExtraData;
446-
(*extra).running_userdata_gc = true;
446+
(*extra).running_gc = true;
447447
// Luau does not support _any_ panics in destructors (they are declared as "C", NOT as "C-unwind"),
448448
// so any panics will trigger `abort()`.
449449
ptr::drop_in_place(ud as *mut T);
450-
(*extra).running_userdata_gc = false;
450+
(*extra).running_gc = false;
451451
}
452452

453453
// This method can be called by user or Lua GC to destroy the userdata.

tests/luau.rs

+36-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![cfg(feature = "luau")]
22

3+
use std::cell::Cell;
34
use std::fmt::Debug;
45
use std::fs;
56
use std::os::raw::c_void;
@@ -419,17 +420,19 @@ fn test_thread_events() -> Result<()> {
419420
let thread_data: Arc<(AtomicPtr<c_void>, AtomicBool)> = Arc::new(Default::default());
420421

421422
let (count2, thread_data2) = (count.clone(), thread_data.clone());
422-
lua.set_thread_event_callback(move |_, value| {
423+
lua.set_thread_creation_callback(move |_, thread| {
423424
count2.fetch_add(1, Ordering::Relaxed);
424-
(thread_data2.0).store(value.to_pointer() as *mut _, Ordering::Relaxed);
425-
if value.is_thread() {
426-
thread_data2.1.store(false, Ordering::Relaxed);
427-
}
428-
if value.is_light_userdata() {
429-
thread_data2.1.store(true, Ordering::Relaxed);
430-
}
425+
(thread_data2.0).store(thread.to_pointer() as *mut _, Ordering::Relaxed);
426+
thread_data2.1.store(false, Ordering::Relaxed);
431427
Ok(())
432428
});
429+
let (count3, thread_data3) = (count.clone(), thread_data.clone());
430+
lua.set_thread_collection_callback(move |thread_ptr| {
431+
count3.fetch_add(1, Ordering::Relaxed);
432+
if thread_data3.0.load(Ordering::Relaxed) == thread_ptr.0 {
433+
thread_data3.1.store(true, Ordering::Relaxed);
434+
}
435+
});
433436

434437
let t = lua.create_thread(lua.load("return 123").into_function()?)?;
435438
assert_eq!(count.load(Ordering::Relaxed), 1);
@@ -445,27 +448,47 @@ fn test_thread_events() -> Result<()> {
445448
assert!(thread_data.1.load(Ordering::Relaxed));
446449

447450
// Check that recursion is not allowed
448-
let count3 = count.clone();
449-
lua.set_thread_event_callback(move |lua, _value| {
450-
count3.fetch_add(1, Ordering::Relaxed);
451+
let count4 = count.clone();
452+
lua.set_thread_creation_callback(move |lua, _value| {
453+
count4.fetch_add(1, Ordering::Relaxed);
451454
let _ = lua.create_thread(lua.load("return 123").into_function().unwrap())?;
452455
Ok(())
453456
});
454457
let t = lua.create_thread(lua.load("return 123").into_function()?)?;
455458
assert_eq!(count.load(Ordering::Relaxed), 3);
456459

457-
lua.remove_thread_event_callback();
460+
lua.remove_thread_callbacks();
458461
drop(t);
459462
lua.gc_collect()?;
460463
assert_eq!(count.load(Ordering::Relaxed), 3);
461464

462465
// Test error inside callback
463-
lua.set_thread_event_callback(move |_, _| Err(Error::runtime("error when processing thread event")));
466+
lua.set_thread_creation_callback(move |_, _| Err(Error::runtime("error when processing thread event")));
464467
let result = lua.create_thread(lua.load("return 123").into_function()?);
465468
assert!(result.is_err());
466469
assert!(
467470
matches!(result, Err(Error::RuntimeError(err)) if err.contains("error when processing thread event"))
468471
);
469472

473+
// Test context switch when running Lua script
474+
let count = Cell::new(0);
475+
lua.set_thread_creation_callback(move |_, _| {
476+
count.set(count.get() + 1);
477+
if count.get() == 2 {
478+
return Err(Error::runtime("thread limit exceeded"));
479+
}
480+
Ok(())
481+
});
482+
let result = lua
483+
.load(
484+
r#"
485+
local co = coroutine.wrap(function() return coroutine.create(print) end)
486+
co()
487+
"#,
488+
)
489+
.exec();
490+
assert!(result.is_err());
491+
assert!(matches!(result, Err(Error::RuntimeError(err)) if err.contains("thread limit exceeded")));
492+
470493
Ok(())
471494
}

0 commit comments

Comments
 (0)