diff --git a/src/lib.rs b/src/lib.rs index a404594c..4d801a5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,7 +113,7 @@ pub use crate::traits::{ FromLua, FromLuaMulti, IntoLua, IntoLuaMulti, LuaNativeFn, LuaNativeFnMut, ObjectLike, }; pub use crate::types::{ - AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState, + AppDataRef, AppDataRefMut, Either, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState, ThreadEventInfo }; pub use crate::userdata::{ AnyUserData, MetaMethod, UserData, UserDataFields, UserDataMetatable, UserDataMethods, UserDataRef, diff --git a/src/prelude.rs b/src/prelude.rs index 68ba8f2f..76397e34 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -14,6 +14,7 @@ pub use crate::{ UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, VmState as LuaVmState, + ThreadEventInfo as LuaThreadEventInfo }; #[cfg(not(feature = "luau"))] diff --git a/src/state.rs b/src/state.rs index 35d9e4ec..a29884ae 100644 --- a/src/state.rs +++ b/src/state.rs @@ -21,7 +21,7 @@ use crate::thread::Thread; use crate::traits::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; use crate::types::{ AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LuaType, MaybeSend, Number, ReentrantMutex, - ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak, + ReentrantMutexGuard, RegistryKey, VmState, XRc, XWeak, ThreadEventInfo }; use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataStorage}; use crate::util::{ @@ -671,6 +671,71 @@ impl Lua { } } + /// Sets a callback that will be called by Luau whenever a thread is created/destroyed. + /// + /// Often used for keeping track of threads. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn set_thread_event_callback(&self, callback: F) + where + F: Fn(&Lua, ThreadEventInfo) -> Result<()> + MaybeSend + 'static, + { + use std::rc::Rc; + + unsafe extern "C-unwind" fn userthread_proc(parent: *mut ffi::lua_State, state: *mut ffi::lua_State) { + callback_error_ext(state, ptr::null_mut(), move |extra, _| { + let raw_lua: &RawLua = (*extra).raw_lua(); + let _guard = StateGuard::new(raw_lua, state); + + let userthread_cb = (*extra).userthread_callback.clone(); + let userthread_cb = mlua_expect!(userthread_cb, "no userthread callback set in userthread_proc"); + if parent.is_null() { + raw_lua.push(Value::Nil).unwrap(); + } else { + raw_lua.push_ref_thread(parent).unwrap(); + } + if parent.is_null() { + let event_info = ThreadEventInfo::Destroying(state.cast_const().cast()); + let main_state = raw_lua.main_state(); + if main_state == state { + return Ok(()); // Don't process Destroying event on main thread. + } + let main_extra = ExtraData::get(main_state); + let main_raw_lua: &RawLua = (*main_extra).raw_lua(); + let _guard = StateGuard::new(main_raw_lua, state); + userthread_cb((*main_extra).lua(), event_info) + } else { + raw_lua.push_ref_thread(parent).unwrap(); + let event_info = match raw_lua.pop_value() { + Value::Thread(thr) => ThreadEventInfo::Created(thr), + _ => unimplemented!() + }; + userthread_cb((*extra).lua(), event_info) + } + }); + } + + // Set interrupt callback + let lua = self.lock(); + unsafe { + (*lua.extra.get()).userthread_callback = Some(Rc::new(callback)); + (*ffi::lua_callbacks(lua.main_state())).userthread = Some(userthread_proc); + } + } + + /// Removes any thread event function previously set by `set_thread_event_callback`. + /// + /// This function has no effect if a callback was not previously set. + #[cfg(any(feature = "luau", doc))] + #[cfg_attr(docsrs, doc(cfg(feature = "luau")))] + pub fn remove_thread_event_callback(&self) { + let lua = self.lock(); + unsafe { + (*lua.extra.get()).userthread_callback = None; + (*ffi::lua_callbacks(lua.main_state())).userthread = None; + } + } + /// Sets the warning function to be used by Lua to emit warnings. /// /// Requires `feature = "lua54"` diff --git a/src/state/extra.rs b/src/state/extra.rs index d1823b5c..2937b106 100644 --- a/src/state/extra.rs +++ b/src/state/extra.rs @@ -80,6 +80,8 @@ pub(crate) struct ExtraData { pub(super) warn_callback: Option, #[cfg(feature = "luau")] pub(super) interrupt_callback: Option, + #[cfg(feature = "luau")] + pub(super) userthread_callback: Option, #[cfg(feature = "luau")] pub(super) sandboxed: bool, @@ -177,6 +179,8 @@ impl ExtraData { #[cfg(feature = "luau")] interrupt_callback: None, #[cfg(feature = "luau")] + userthread_callback: None, + #[cfg(feature = "luau")] sandboxed: false, #[cfg(feature = "luau")] compiler: None, diff --git a/src/state/raw.rs b/src/state/raw.rs index 0731f846..c86a239c 100644 --- a/src/state/raw.rs +++ b/src/state/raw.rs @@ -64,7 +64,8 @@ impl Drop for RawLua { } let mem_state = MemoryState::get(self.main_state()); - + #[cfg(feature = "luau")] // Fixes a crash during shutdown + { (*ffi::lua_callbacks(self.main_state())).userthread = None; } ffi::lua_close(self.main_state()); // Deallocate `MemoryState` @@ -556,6 +557,21 @@ impl RawLua { value.push_into_stack(self) } + pub(crate) unsafe fn push_ref_thread(&self, ref_thread: *mut ffi::lua_State) -> Result<()> { + let state = self.state(); + check_stack(state, 1)?; + let _sg = StackGuard::new(ref_thread); + check_stack(ref_thread, 1)?; + + if self.unlikely_memory_error() { + ffi::lua_pushthread(ref_thread) + } else { + protect_lua!(ref_thread, 0, 1, |ref_thread| ffi::lua_pushthread(ref_thread))? + }; + ffi::lua_xmove(ref_thread, self.state(), 1); + Ok(()) + } + /// Pushes a `Value` (by reference) onto the Lua stack. /// /// Uses 2 stack spaces, does not call `checkstack`. diff --git a/src/types.rs b/src/types.rs index afeb239d..0f0d6019 100644 --- a/src/types.rs +++ b/src/types.rs @@ -6,6 +6,7 @@ use crate::error::Result; #[cfg(not(feature = "luau"))] use crate::hook::Debug; use crate::state::{ExtraData, Lua, RawLua}; +use crate::thread::Thread; // Re-export mutex wrappers pub(crate) use sync::{ArcReentrantMutexGuard, ReentrantMutex, ReentrantMutexGuard, XRc, XWeak}; @@ -73,6 +74,17 @@ pub enum VmState { Yield, } +/// Information about a thread event. +/// +/// For creating a thread, it contains the thread that created it. +/// +/// This is useful for tracking the origin of all threads. +#[cfg(feature = "luau")] +pub enum ThreadEventInfo { + Created(Thread), + Destroying(*const ()) // Pointer of thread +} + #[cfg(all(feature = "send", not(feature = "luau")))] pub(crate) type HookCallback = Rc Result + Send>; @@ -85,6 +97,13 @@ pub(crate) type InterruptCallback = Rc Result + Send>; #[cfg(all(not(feature = "send"), feature = "luau"))] pub(crate) type InterruptCallback = Rc Result>; +#[cfg(all(feature = "send", feature = "luau"))] +pub(crate) type ThreadEventCallback = Rc Result<()> + Send>; + +#[cfg(all(not(feature = "send"), feature = "luau"))] +pub(crate) type ThreadEventCallback = Rc Result<()>>; + + #[cfg(all(feature = "send", feature = "lua54"))] pub(crate) type WarnCallback = Box Result<()> + Send>;