Skip to content

Commit 1f97082

Browse files
committed
Support user thread creation callback (Luau)
1 parent 8d864e2 commit 1f97082

File tree

8 files changed

+160
-26
lines changed

8 files changed

+160
-26
lines changed

src/state.rs

+58-4
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ impl Lua {
419419
// Make sure that Lua is initialized
420420
let _ = Self::get_or_init_from_ptr(state);
421421

422-
callback_error_ext(state, ptr::null_mut(), move |extra, nargs| {
422+
callback_error_ext(state, ptr::null_mut(), true, move |extra, nargs| {
423423
let rawlua = (*extra).raw_lua();
424424
let _guard = StateGuard::new(rawlua, state);
425425
let args = A::from_stack_args(nargs, 1, None, rawlua)?;
@@ -652,7 +652,7 @@ impl Lua {
652652
// We don't support GC interrupts since they cannot survive Lua exceptions
653653
return;
654654
}
655-
let result = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
655+
let result = callback_error_ext(state, ptr::null_mut(), false, move |extra, _| {
656656
let interrupt_cb = (*extra).interrupt_callback.clone();
657657
let interrupt_cb = mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc");
658658
if XRc::strong_count(&interrupt_cb) > 2 {
@@ -690,6 +690,60 @@ impl Lua {
690690
}
691691
}
692692

693+
/// Sets a thread event callback that will be called when a thread is created or destroyed.
694+
///
695+
/// The callback is called with a [`Value`] argument that is either:
696+
/// - A [`Thread`] object when thread is created
697+
/// - A [`LightUserData`] when thread is destroyed
698+
#[cfg(any(feature = "luau", doc))]
699+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
700+
pub fn set_thread_event_callback<F>(&self, callback: F)
701+
where
702+
F: Fn(&Lua, Value) -> Result<()> + MaybeSend + 'static,
703+
{
704+
unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, child: *mut ffi::lua_State) {
705+
let extra = ExtraData::get(child);
706+
let thread_cb = match (*extra).userthread_callback {
707+
Some(ref cb) => cb.clone(),
708+
None => return,
709+
};
710+
if XRc::strong_count(&thread_cb) > 2 {
711+
return; // Don't allow recursion
712+
}
713+
let value = match parent.is_null() {
714+
// Thread is about to be destroyed, pass light userdata
715+
true => Value::LightUserData(crate::LightUserData(child as _)),
716+
false => {
717+
// Thread is created, pass thread object
718+
ffi::lua_pushthread(child);
719+
ffi::lua_xmove(child, (*extra).ref_thread, 1);
720+
Value::Thread(Thread((*extra).raw_lua().pop_ref_thread(), child))
721+
}
722+
};
723+
callback_error_ext((*extra).raw_lua().state(), extra, false, move |extra, _| {
724+
thread_cb((*extra).lua(), value)
725+
})
726+
}
727+
728+
// Set thread callback
729+
let lua = self.lock();
730+
unsafe {
731+
(*lua.extra.get()).userthread_callback = Some(XRc::new(callback));
732+
(*ffi::lua_callbacks(lua.main_state())).userthread = Some(userthread_proc);
733+
}
734+
}
735+
736+
/// Removes any thread event callback previously set by `set_thread_event_callback`.
737+
#[cfg(any(feature = "luau", doc))]
738+
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
739+
pub fn remove_thread_event_callback(&self) {
740+
let lua = self.lock();
741+
unsafe {
742+
(*lua.extra.get()).userthread_callback = None;
743+
(*ffi::lua_callbacks(lua.main_state())).userthread = None;
744+
}
745+
}
746+
693747
/// Sets the warning function to be used by Lua to emit warnings.
694748
///
695749
/// Requires `feature = "lua54"`
@@ -705,7 +759,7 @@ impl Lua {
705759

706760
unsafe extern "C-unwind" fn warn_proc(ud: *mut c_void, msg: *const c_char, tocont: c_int) {
707761
let extra = ud as *mut ExtraData;
708-
callback_error_ext((*extra).raw_lua().state(), extra, |extra, _| {
762+
callback_error_ext((*extra).raw_lua().state(), extra, false, |extra, _| {
709763
let warn_callback = (*extra).warn_callback.clone();
710764
let warn_callback = mlua_expect!(warn_callback, "no warning callback set in warn_proc");
711765
if XRc::strong_count(&warn_callback) > 2 {
@@ -1444,7 +1498,7 @@ impl Lua {
14441498
Err(_) => return,
14451499
},
14461500
ffi::LUA_TTHREAD => {
1447-
ffi::lua_newthread(state);
1501+
ffi::lua_pushthread(state);
14481502
}
14491503
#[cfg(feature = "luau")]
14501504
ffi::LUA_TBUFFER => {

src/state/extra.rs

+4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ pub(crate) struct ExtraData {
8080
pub(super) warn_callback: Option<crate::types::WarnCallback>,
8181
#[cfg(feature = "luau")]
8282
pub(super) interrupt_callback: Option<crate::types::InterruptCallback>,
83+
#[cfg(feature = "luau")]
84+
pub(super) userthread_callback: Option<crate::types::UserThreadCallback>,
8385

8486
#[cfg(feature = "luau")]
8587
pub(super) sandboxed: bool,
@@ -177,6 +179,8 @@ impl ExtraData {
177179
#[cfg(feature = "luau")]
178180
interrupt_callback: None,
179181
#[cfg(feature = "luau")]
182+
userthread_callback: None,
183+
#[cfg(feature = "luau")]
180184
sandboxed: false,
181185
#[cfg(feature = "luau")]
182186
compiler: None,

src/state/raw.rs

+17-6
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ impl Drop for RawLua {
7070

7171
let mem_state = MemoryState::get(self.main_state());
7272

73+
#[cfg(feature = "luau")]
74+
{
75+
// Reset any callbacks
76+
(*ffi::lua_callbacks(self.main_state())).interrupt = None;
77+
(*ffi::lua_callbacks(self.main_state())).userthread = None;
78+
}
79+
7380
ffi::lua_close(self.main_state());
7481

7582
// Deallocate `MemoryState`
@@ -420,7 +427,7 @@ impl RawLua {
420427
}
421428

422429
unsafe extern "C-unwind" fn global_hook_proc(state: *mut ffi::lua_State, ar: *mut ffi::lua_Debug) {
423-
let status = callback_error_ext(state, ptr::null_mut(), move |extra, _| {
430+
let status = callback_error_ext(state, ptr::null_mut(), false, move |extra, _| {
424431
match (*extra).hook_callback.clone() {
425432
Some(hook_callback) => {
426433
let rawlua = (*extra).raw_lua();
@@ -453,7 +460,7 @@ impl RawLua {
453460
return;
454461
}
455462

456-
let status = callback_error_ext(state, ptr::null_mut(), |extra, _| {
463+
let status = callback_error_ext(state, ptr::null_mut(), false, |extra, _| {
457464
let rawlua = (*extra).raw_lua();
458465
let _guard = StateGuard::new(rawlua, state);
459466
let debug = Debug::new(rawlua, ar);
@@ -564,7 +571,11 @@ impl RawLua {
564571
let _sg = StackGuard::new(state);
565572
check_stack(state, 3)?;
566573

567-
let thread_state = if self.unlikely_memory_error() {
574+
let protect = !self.unlikely_memory_error();
575+
#[cfg(feature = "luau")]
576+
let protect = protect || (*self.extra.get()).userthread_callback.is_some();
577+
578+
let thread_state = if !protect {
568579
ffi::lua_newthread(state)
569580
} else {
570581
protect_lua!(state, 0, 1, |state| ffi::lua_newthread(state))?
@@ -1177,7 +1188,7 @@ impl RawLua {
11771188
pub(crate) fn create_callback(&self, func: Callback) -> Result<Function> {
11781189
unsafe extern "C-unwind" fn call_callback(state: *mut ffi::lua_State) -> c_int {
11791190
let upvalue = get_userdata::<CallbackUpvalue>(state, ffi::lua_upvalueindex(1));
1180-
callback_error_ext(state, (*upvalue).extra.get(), |extra, nargs| {
1191+
callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| {
11811192
// Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments)
11821193
// The lock must be already held as the callback is executed
11831194
let rawlua = (*extra).raw_lua();
@@ -1226,7 +1237,7 @@ impl RawLua {
12261237
// so the first upvalue is always valid
12271238
let upvalue = get_userdata::<AsyncCallbackUpvalue>(state, ffi::lua_upvalueindex(1));
12281239
let extra = (*upvalue).extra.get();
1229-
callback_error_ext(state, extra, |extra, nargs| {
1240+
callback_error_ext(state, extra, true, |extra, nargs| {
12301241
// Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments)
12311242
// The lock must be already held as the callback is executed
12321243
let rawlua = (*extra).raw_lua();
@@ -1251,7 +1262,7 @@ impl RawLua {
12511262

12521263
unsafe extern "C-unwind" fn poll_future(state: *mut ffi::lua_State) -> c_int {
12531264
let upvalue = get_userdata::<AsyncPollUpvalue>(state, ffi::lua_upvalueindex(1));
1254-
callback_error_ext(state, (*upvalue).extra.get(), |extra, _| {
1265+
callback_error_ext(state, (*upvalue).extra.get(), true, |extra, _| {
12551266
// Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments)
12561267
// The lock must be already held as the future is polled
12571268
let rawlua = (*extra).raw_lua();

src/state/util.rs

+8
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ impl Drop for StateGuard<'_> {
2727
pub(super) unsafe fn callback_error_ext<F, R>(
2828
state: *mut ffi::lua_State,
2929
mut extra: *mut ExtraData,
30+
wrap_error: bool,
3031
f: F,
3132
) -> R
3233
where
@@ -110,6 +111,13 @@ where
110111
Ok(Err(err)) => {
111112
let wrapped_error = prealloc_failure.r#use(state, extra);
112113

114+
if !wrap_error {
115+
ptr::write(wrapped_error, WrappedFailure::Error(err));
116+
get_internal_metatable::<WrappedFailure>(state);
117+
ffi::lua_setmetatable(state, -2);
118+
ffi::lua_error(state)
119+
}
120+
113121
// Build `CallbackError` with traceback
114122
let traceback = if ffi::lua_checkstack(state, ffi::LUA_TRACEBACK_STACK) != 0 {
115123
ffi::luaL_traceback(state, state, ptr::null(), 0);

src/types.rs

+6
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState> + Send>;
9090
#[cfg(all(not(feature = "send"), feature = "luau"))]
9191
pub(crate) type InterruptCallback = XRc<dyn Fn(&Lua) -> Result<VmState>>;
9292

93+
#[cfg(all(feature = "send", feature = "luau"))]
94+
pub(crate) type UserThreadCallback = XRc<dyn Fn(&Lua, crate::Value) -> Result<()> + Send>;
95+
96+
#[cfg(all(not(feature = "send"), feature = "luau"))]
97+
pub(crate) type UserThreadCallback = XRc<dyn Fn(&Lua, crate::Value) -> Result<()>>;
98+
9399
#[cfg(all(feature = "send", feature = "lua54"))]
94100
pub(crate) type WarnCallback = XRc<dyn Fn(&Lua, &str, bool) -> Result<()> + Send>;
95101

tests/hooks.rs

+3-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#![cfg(not(feature = "luau"))]
22

3-
use std::ops::Deref;
43
use std::sync::atomic::{AtomicI64, Ordering};
54
use std::sync::{Arc, Mutex};
65

@@ -104,14 +103,10 @@ fn test_error_within_hook() -> Result<()> {
104103
})?;
105104

106105
let err = lua.load("x = 1").exec().expect_err("panic didn't propagate");
107-
108106
match err {
109-
Error::CallbackError { cause, .. } => match cause.deref() {
110-
Error::RuntimeError(s) => assert_eq!(s, "Something happened in there!"),
111-
_ => panic!("wrong callback error kind caught"),
112-
},
113-
_ => panic!("wrong error kind caught"),
114-
};
107+
Error::RuntimeError(msg) => assert_eq!(msg, "Something happened in there!"),
108+
err => panic!("expected `RuntimeError` with a specific message, got {err:?}"),
109+
}
115110

116111
Ok(())
117112
}

tests/luau.rs

+63-6
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
use std::fmt::Debug;
44
use std::fs;
5+
use std::os::raw::c_void;
56
use std::panic::{catch_unwind, AssertUnwindSafe};
6-
use std::sync::atomic::{AtomicU64, Ordering};
7+
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU64, Ordering};
78
use std::sync::Arc;
89

910
use mlua::{Compiler, Error, Lua, LuaOptions, Result, StdLib, Table, ThreadStatus, Value, Vector, VmState};
@@ -395,11 +396,8 @@ fn test_interrupts() -> Result<()> {
395396
//
396397
lua.set_interrupt(|_| Err(Error::runtime("error from interrupt")));
397398
match f.call::<()>(()) {
398-
Err(Error::CallbackError { cause, .. }) => match *cause {
399-
Error::RuntimeError(ref m) if m == "error from interrupt" => {}
400-
ref e => panic!("expected RuntimeError with a specific message, got {:?}", e),
401-
},
402-
r => panic!("expected CallbackError, got {:?}", r),
399+
Err(Error::RuntimeError(ref msg)) => assert_eq!(msg, "error from interrupt"),
400+
res => panic!("expected `RuntimeError` with a specific message, got {res:?}"),
403401
}
404402

405403
lua.remove_interrupt();
@@ -412,3 +410,62 @@ fn test_fflags() {
412410
// We cannot really on any particular feature flag to be present
413411
assert!(Lua::set_fflag("UnknownFlag", true).is_err());
414412
}
413+
414+
#[test]
415+
fn test_thread_events() -> Result<()> {
416+
let lua = Lua::new();
417+
418+
let count = Arc::new(AtomicU64::new(0));
419+
let thread_data: Arc<(AtomicPtr<c_void>, AtomicBool)> = Arc::new(Default::default());
420+
421+
let (count2, thread_data2) = (count.clone(), thread_data.clone());
422+
lua.set_thread_event_callback(move |_, value| {
423+
count2.fetch_add(1, Ordering::Relaxed);
424+
(thread_data2.0).store(value.to_pointer() as *mut _, Ordering::Relaxed);
425+
if value.is_thread() {
426+
thread_data2.1.store(false, Ordering::Relaxed);
427+
}
428+
if value.is_light_userdata() {
429+
thread_data2.1.store(true, Ordering::Relaxed);
430+
}
431+
Ok(())
432+
});
433+
434+
let t = lua.create_thread(lua.load("return 123").into_function()?)?;
435+
assert_eq!(count.load(Ordering::Relaxed), 1);
436+
let t_ptr = t.to_pointer();
437+
assert_eq!(t_ptr, thread_data.0.load(Ordering::Relaxed));
438+
assert!(!thread_data.1.load(Ordering::Relaxed));
439+
440+
// Thead will be destroyed after GC cycle
441+
drop(t);
442+
lua.gc_collect()?;
443+
assert_eq!(count.load(Ordering::Relaxed), 2);
444+
assert_eq!(t_ptr, thread_data.0.load(Ordering::Relaxed));
445+
assert!(thread_data.1.load(Ordering::Relaxed));
446+
447+
// Check that recursion is not allowed
448+
let count3 = count.clone();
449+
lua.set_thread_event_callback(move |lua, _value| {
450+
count3.fetch_add(1, Ordering::Relaxed);
451+
let _ = lua.create_thread(lua.load("return 123").into_function().unwrap())?;
452+
Ok(())
453+
});
454+
let t = lua.create_thread(lua.load("return 123").into_function()?)?;
455+
assert_eq!(count.load(Ordering::Relaxed), 3);
456+
457+
lua.remove_thread_event_callback();
458+
drop(t);
459+
lua.gc_collect()?;
460+
assert_eq!(count.load(Ordering::Relaxed), 3);
461+
462+
// Test error inside callback
463+
lua.set_thread_event_callback(move |_, _| Err(Error::runtime("error when processing thread event")));
464+
let result = lua.create_thread(lua.load("return 123").into_function()?);
465+
assert!(result.is_err());
466+
assert!(
467+
matches!(result, Err(Error::RuntimeError(err)) if err.contains("error when processing thread event"))
468+
);
469+
470+
Ok(())
471+
}

tests/tests.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1308,8 +1308,7 @@ fn test_warnings() -> Result<()> {
13081308
lua.set_warning_function(|_, _, _| Err(Error::runtime("warning error")));
13091309
assert!(matches!(
13101310
lua.load(r#"warn("test")"#).exec(),
1311-
Err(Error::CallbackError { cause, .. })
1312-
if matches!(*cause, Error::RuntimeError(ref err) if err == "warning error")
1311+
Err(Error::RuntimeError(ref err)) if err == "warning error"
13131312
));
13141313

13151314
// Recursive warning

0 commit comments

Comments
 (0)