From 91fe02da45b40623295b8db9e054e7c8b35fe9f6 Mon Sep 17 00:00:00 2001 From: Alex Orlenko Date: Tue, 24 Sep 2024 14:31:28 +0100 Subject: [PATCH] Add `LuaNativeFn`/`LuaNativeFnMut`/`LuaNativeAsyncFn` traits for using in `Function::wrap` --- src/function.rs | 68 ++++++++++++++++++++++----- src/lib.rs | 4 +- src/prelude.rs | 19 ++++---- src/state.rs | 6 +++ src/traits.rs | 92 +++++++++++++++++++++++++++++++++++++ tests/async.rs | 29 +++++++++++- tests/chunk.rs | 2 +- tests/function.rs | 114 +++++++++++++++++++++++++++++++++++++--------- tests/types.rs | 12 ++--- 9 files changed, 293 insertions(+), 53 deletions(-) diff --git a/src/function.rs b/src/function.rs index fdf4c1a7..37211122 100644 --- a/src/function.rs +++ b/src/function.rs @@ -5,6 +5,7 @@ use std::{mem, ptr, slice}; use crate::error::{Error, Result}; use crate::state::Lua; use crate::table::Table; +use crate::traits::{LuaNativeFn, LuaNativeFnMut}; use crate::types::{Callback, LuaType, MaybeSend, ValueRef}; use crate::util::{ assert_stack, check_stack, linenumber_to_usize, pop_error, ptr_to_lossy_str, ptr_to_str, StackGuard, @@ -13,6 +14,7 @@ use crate::value::{FromLuaMulti, IntoLua, IntoLuaMulti, Value}; #[cfg(feature = "async")] use { + crate::traits::LuaNativeAsyncFn, crate::types::AsyncCallback, std::future::{self, Future}, }; @@ -522,31 +524,56 @@ impl Function { /// Wraps a Rust function or closure, returning an opaque type that implements [`IntoLua`] /// trait. #[inline] - pub fn wrap(func: F) -> impl IntoLua + pub fn wrap(func: F) -> impl IntoLua where + F: LuaNativeFn> + MaybeSend + 'static, A: FromLuaMulti, R: IntoLuaMulti, - F: Fn(&Lua, A) -> Result + MaybeSend + 'static, { WrappedFunction(Box::new(move |lua, nargs| unsafe { let args = A::from_stack_args(nargs, 1, None, lua)?; - func(lua.lua(), args)?.push_into_stack_multi(lua) + func.call(args)?.push_into_stack_multi(lua) })) } /// Wraps a Rust mutable closure, returning an opaque type that implements [`IntoLua`] trait. - #[inline] - pub fn wrap_mut(func: F) -> impl IntoLua + pub fn wrap_mut(func: F) -> impl IntoLua where + F: LuaNativeFnMut> + MaybeSend + 'static, A: FromLuaMulti, R: IntoLuaMulti, - F: FnMut(&Lua, A) -> Result + MaybeSend + 'static, { let func = RefCell::new(func); WrappedFunction(Box::new(move |lua, nargs| unsafe { let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?; let args = A::from_stack_args(nargs, 1, None, lua)?; - func(lua.lua(), args)?.push_into_stack_multi(lua) + func.call(args)?.push_into_stack_multi(lua) + })) + } + + #[inline] + pub fn wrap_raw(func: F) -> impl IntoLua + where + F: LuaNativeFn + MaybeSend + 'static, + A: FromLuaMulti, + { + WrappedFunction(Box::new(move |lua, nargs| unsafe { + let args = A::from_stack_args(nargs, 1, None, lua)?; + func.call(args).push_into_stack_multi(lua) + })) + } + + #[inline] + pub fn wrap_raw_mut(func: F) -> impl IntoLua + where + F: LuaNativeFnMut + MaybeSend + 'static, + A: FromLuaMulti, + { + let func = RefCell::new(func); + WrappedFunction(Box::new(move |lua, nargs| unsafe { + let mut func = func.try_borrow_mut().map_err(|_| Error::RecursiveMutCallback)?; + let args = A::from_stack_args(nargs, 1, None, lua)?; + func.call(args).push_into_stack_multi(lua) })) } @@ -554,23 +581,40 @@ impl Function { /// trait. #[cfg(feature = "async")] #[cfg_attr(docsrs, doc(cfg(feature = "async")))] - pub fn wrap_async(func: F) -> impl IntoLua + pub fn wrap_async(func: F) -> impl IntoLua where + F: LuaNativeAsyncFn> + MaybeSend + 'static, A: FromLuaMulti, R: IntoLuaMulti, - F: Fn(Lua, A) -> FR + MaybeSend + 'static, - FR: Future> + MaybeSend + 'static, { WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe { let args = match A::from_stack_args(nargs, 1, None, rawlua) { Ok(args) => args, Err(e) => return Box::pin(future::ready(Err(e))), }; - let lua = rawlua.lua().clone(); - let fut = func(lua.clone(), args); + let lua = rawlua.lua(); + let fut = func.call(args); Box::pin(async move { fut.await?.push_into_stack_multi(lua.raw_lua()) }) })) } + + #[cfg(feature = "async")] + #[cfg_attr(docsrs, doc(cfg(feature = "async")))] + pub fn wrap_raw_async(func: F) -> impl IntoLua + where + F: LuaNativeAsyncFn + MaybeSend + 'static, + A: FromLuaMulti, + { + WrappedAsyncFunction(Box::new(move |rawlua, nargs| unsafe { + let args = match A::from_stack_args(nargs, 1, None, rawlua) { + Ok(args) => args, + Err(e) => return Box::pin(future::ready(Err(e))), + }; + let lua = rawlua.lua(); + let fut = func.call(args); + Box::pin(async move { fut.await.push_into_stack_multi(lua.raw_lua()) }) + })) + } } impl IntoLua for WrappedFunction { diff --git a/src/lib.rs b/src/lib.rs index c402e11c..4d489aa5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -115,7 +115,7 @@ pub use crate::stdlib::StdLib; pub use crate::string::{BorrowedBytes, BorrowedStr, String}; pub use crate::table::{Table, TablePairs, TableSequence}; pub use crate::thread::{Thread, ThreadStatus}; -pub use crate::traits::ObjectLike; +pub use crate::traits::{LuaNativeFn, LuaNativeFnMut, ObjectLike}; pub use crate::types::{ AppDataRef, AppDataRefMut, Integer, LightUserData, MaybeSend, Number, RegistryKey, VmState, }; @@ -133,7 +133,7 @@ pub use crate::hook::HookTriggers; pub use crate::{chunk::Compiler, function::CoverageInfo, types::Vector}; #[cfg(feature = "async")] -pub use crate::thread::AsyncThread; +pub use crate::{thread::AsyncThread, traits::LuaNativeAsyncFn}; #[cfg(feature = "serialize")] #[doc(inline)] diff --git a/src/prelude.rs b/src/prelude.rs index 329f2ee7..e26649ce 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -5,14 +5,15 @@ pub use crate::{ AnyUserData as LuaAnyUserData, Chunk as LuaChunk, Error as LuaError, ErrorContext as LuaErrorContext, ExternalError as LuaExternalError, ExternalResult as LuaExternalResult, FromLua, FromLuaMulti, Function as LuaFunction, FunctionInfo as LuaFunctionInfo, GCMode as LuaGCMode, Integer as LuaInteger, - IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaOptions, MetaMethod as LuaMetaMethod, - MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, ObjectLike as LuaObjectLike, - RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, String as LuaString, - Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, Thread as LuaThread, - ThreadStatus as LuaThreadStatus, UserData as LuaUserData, UserDataFields as LuaUserDataFields, - UserDataMetatable as LuaUserDataMetatable, UserDataMethods as LuaUserDataMethods, - UserDataRef as LuaUserDataRef, UserDataRefMut as LuaUserDataRefMut, - UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, VmState as LuaVmState, + IntoLua, IntoLuaMulti, LightUserData as LuaLightUserData, Lua, LuaNativeFn, LuaNativeFnMut, LuaOptions, + MetaMethod as LuaMetaMethod, MultiValue as LuaMultiValue, Nil as LuaNil, Number as LuaNumber, + ObjectLike as LuaObjectLike, RegistryKey as LuaRegistryKey, Result as LuaResult, StdLib as LuaStdLib, + String as LuaString, Table as LuaTable, TablePairs as LuaTablePairs, TableSequence as LuaTableSequence, + Thread as LuaThread, ThreadStatus as LuaThreadStatus, UserData as LuaUserData, + UserDataFields as LuaUserDataFields, UserDataMetatable as LuaUserDataMetatable, + UserDataMethods as LuaUserDataMethods, UserDataRef as LuaUserDataRef, + UserDataRefMut as LuaUserDataRefMut, UserDataRegistry as LuaUserDataRegistry, Value as LuaValue, + VmState as LuaVmState, }; #[cfg(not(feature = "luau"))] @@ -25,7 +26,7 @@ pub use crate::{CoverageInfo as LuaCoverageInfo, Vector as LuaVector}; #[cfg(feature = "async")] #[doc(no_inline)] -pub use crate::AsyncThread as LuaAsyncThread; +pub use crate::{AsyncThread as LuaAsyncThread, LuaNativeAsyncFn}; #[cfg(feature = "serialize")] #[doc(no_inline)] diff --git a/src/state.rs b/src/state.rs index d8f0b57a..0a93c798 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1553,6 +1553,12 @@ impl Lua { T::from_lua(value, self) } + /// Converts a value that implements `IntoLua` into a `FromLua` variant. + #[inline] + pub fn convert(&self, value: impl IntoLua) -> Result { + U::from_lua(value.into_lua(self)?, self) + } + /// Converts a value that implements `IntoLuaMulti` into a `MultiValue` instance. #[inline] pub fn pack_multi(&self, t: impl IntoLuaMulti) -> Result { diff --git a/src/traits.rs b/src/traits.rs index 0fd48d2b..2fd8dc64 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -2,6 +2,7 @@ use std::string::String as StdString; use crate::error::Result; use crate::private::Sealed; +use crate::types::MaybeSend; use crate::value::{FromLua, FromLuaMulti, IntoLua, IntoLuaMulti}; #[cfg(feature = "async")] @@ -76,3 +77,94 @@ pub trait ObjectLike: Sealed { /// This might invoke the `__tostring` metamethod. fn to_string(&self) -> Result; } + +/// A trait for types that can be used as Lua functions. +pub trait LuaNativeFn { + type Output: IntoLuaMulti; + + fn call(&self, args: A) -> Self::Output; +} + +/// A trait for types with mutable state that can be used as Lua functions. +pub trait LuaNativeFnMut { + type Output: IntoLuaMulti; + + fn call(&mut self, args: A) -> Self::Output; +} + +/// A trait for types that returns a future and can be used as Lua functions. +#[cfg(feature = "async")] +pub trait LuaNativeAsyncFn { + type Output: IntoLuaMulti; + + fn call(&self, args: A) -> impl Future + MaybeSend + 'static; +} + +macro_rules! impl_lua_native_fn { + ($($A:ident),*) => { + impl LuaNativeFn<($($A,)*)> for FN + where + FN: Fn($($A,)*) -> R + MaybeSend + 'static, + ($($A,)*): FromLuaMulti, + R: IntoLuaMulti, + { + type Output = R; + + #[allow(non_snake_case)] + fn call(&self, args: ($($A,)*)) -> Self::Output { + let ($($A,)*) = args; + self($($A,)*) + } + } + + impl LuaNativeFnMut<($($A,)*)> for FN + where + FN: FnMut($($A,)*) -> R + MaybeSend + 'static, + ($($A,)*): FromLuaMulti, + R: IntoLuaMulti, + { + type Output = R; + + #[allow(non_snake_case)] + fn call(&mut self, args: ($($A,)*)) -> Self::Output { + let ($($A,)*) = args; + self($($A,)*) + } + } + + #[cfg(feature = "async")] + impl LuaNativeAsyncFn<($($A,)*)> for FN + where + FN: Fn($($A,)*) -> Fut + MaybeSend + 'static, + ($($A,)*): FromLuaMulti, + Fut: Future + MaybeSend + 'static, + R: IntoLuaMulti, + { + type Output = R; + + #[allow(non_snake_case)] + fn call(&self, args: ($($A,)*)) -> impl Future + MaybeSend + 'static { + let ($($A,)*) = args; + self($($A,)*) + } + } + }; +} + +impl_lua_native_fn!(); +impl_lua_native_fn!(A); +impl_lua_native_fn!(A, B); +impl_lua_native_fn!(A, B, C); +impl_lua_native_fn!(A, B, C, D); +impl_lua_native_fn!(A, B, C, D, E); +impl_lua_native_fn!(A, B, C, D, E, F); +impl_lua_native_fn!(A, B, C, D, E, F, G); +impl_lua_native_fn!(A, B, C, D, E, F, G, H); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O); +impl_lua_native_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P); diff --git a/tests/async.rs b/tests/async.rs index 0f4101b4..f939a684 100644 --- a/tests/async.rs +++ b/tests/async.rs @@ -1,5 +1,6 @@ #![cfg(feature = "async")] +use std::string::String as StdString; use std::sync::Arc; use std::time::Duration; @@ -39,12 +40,38 @@ async fn test_async_function() -> Result<()> { async fn test_async_function_wrap() -> Result<()> { let lua = Lua::new(); - let f = Function::wrap_async(|_, s: String| async move { Ok(s) }); + let f = Function::wrap_async(|s: StdString| async move { + tokio::task::yield_now().await; + Ok(s) + }); lua.globals().set("f", f)?; + let res: String = lua.load(r#"f("hello")"#).eval_async().await?; + assert_eq!(res, "hello"); + Ok(()) +} + +#[tokio::test] +async fn test_async_function_wrap_raw() -> Result<()> { + let lua = Lua::new(); + + let f = Function::wrap_raw_async(|s: StdString| async move { + tokio::task::yield_now().await; + s + }); + lua.globals().set("f", f)?; let res: String = lua.load(r#"f("hello")"#).eval_async().await?; assert_eq!(res, "hello"); + // Return error + let ferr = Function::wrap_raw_async(|| async move { + tokio::task::yield_now().await; + Err::<(), _>("some error") + }); + lua.globals().set("ferr", ferr)?; + let (_, err): (Value, String) = lua.load(r#"ferr()"#).eval_async().await?; + assert_eq!(err, "some error"); + Ok(()) } diff --git a/tests/chunk.rs b/tests/chunk.rs index 403a6c56..910cfbff 100644 --- a/tests/chunk.rs +++ b/tests/chunk.rs @@ -42,7 +42,7 @@ fn test_chunk_macro() -> Result<()> { data.raw_set("num", 1)?; let ud = mlua::AnyUserData::wrap("hello"); - let f = mlua::Function::wrap(|_lua, ()| Ok(())); + let f = mlua::Function::wrap(|| Ok(())); lua.globals().set("g", 123)?; diff --git a/tests/function.rs b/tests/function.rs index 7415e5e9..cec11f38 100644 --- a/tests/function.rs +++ b/tests/function.rs @@ -1,4 +1,4 @@ -use mlua::{Function, Lua, Result, String, Table}; +use mlua::{Error, Function, Lua, Result, String, Table}; #[test] fn test_function() -> Result<()> { @@ -271,31 +271,101 @@ fn test_function_deep_clone() -> Result<()> { #[test] fn test_function_wrap() -> Result<()> { - use mlua::Error; - let lua = Lua::new(); - lua.globals().set("f", Function::wrap(|_, s: String| Ok(s)))?; - lua.load(r#"assert(f("hello") == "hello")"#).exec().unwrap(); - - let mut _i = false; - lua.globals().set( - "f", - Function::wrap_mut(move |lua, ()| { - _i = true; - lua.globals().get::("f")?.call::<()>(()) - }), - )?; - match lua.globals().get::("f")?.call::<()>(()) { - Err(Error::CallbackError { ref cause, .. }) => match *cause.as_ref() { - Error::CallbackError { ref cause, .. } => match *cause.as_ref() { - Error::RecursiveMutCallback { .. } => {} - ref other => panic!("incorrect result: {other:?}"), - }, - ref other => panic!("incorrect result: {other:?}"), + let f = Function::wrap(|s: String, n| Ok(s.to_str().unwrap().repeat(n))); + lua.globals().set("f", f)?; + lua.load(r#"assert(f("hello", 2) == "hellohello")"#) + .exec() + .unwrap(); + + // Return error + let ferr = Function::wrap(|| Err::<(), _>(Error::runtime("some error"))); + lua.globals().set("ferr", ferr)?; + lua.load( + r#" + local ok, err = pcall(ferr) + assert(not ok and tostring(err):find("some error")) + "#, + ) + .exec() + .unwrap(); + + // Mutable callback + let mut i = 0; + let fmut = Function::wrap_mut(move || { + i += 1; + Ok(i) + }); + lua.globals().set("fmut", fmut)?; + lua.load(r#"fmut(); fmut(); assert(fmut() == 3)"#).exec().unwrap(); + + // Check mutable callback with error + let fmut_err = Function::wrap_mut(|| Err::<(), _>(Error::runtime("some error"))); + lua.globals().set("fmut_err", fmut_err)?; + lua.load( + r#" + local ok, err = pcall(fmut_err) + assert(not ok and tostring(err):find("some error")) + "#, + ) + .exec() + .unwrap(); + + // Check recursive mut callback error + let fmut = Function::wrap_mut(|f: Function| match f.call::<()>(&f) { + Err(Error::CallbackError { cause, .. }) => match cause.as_ref() { + Error::RecursiveMutCallback { .. } => Ok(()), + other => panic!("incorrect result: {other:?}"), }, other => panic!("incorrect result: {other:?}"), - }; + }); + let fmut = lua.convert::(fmut)?; + assert!(fmut.call::<()>(&fmut).is_ok()); + + Ok(()) +} + +#[test] +fn test_function_wrap_raw() -> Result<()> { + let lua = Lua::new(); + + let f = Function::wrap_raw(|| "hello"); + lua.globals().set("f", f)?; + lua.load(r#"assert(f() == "hello")"#).exec().unwrap(); + + // Return error + let ferr = Function::wrap_raw(|| Err::<(), _>("some error")); + lua.globals().set("ferr", ferr)?; + lua.load( + r#" + local _, err = ferr() + assert(err == "some error") + "#, + ) + .exec() + .unwrap(); + + // Mutable callback + let mut i = 0; + let fmut = Function::wrap_raw_mut(move || { + i += 1; + i + }); + lua.globals().set("fmut", fmut)?; + lua.load(r#"fmut(); fmut(); assert(fmut() == 3)"#).exec().unwrap(); + + // Check mutable callback with error + let fmut_err = Function::wrap_raw_mut(|| Err::<(), _>("some error")); + lua.globals().set("fmut_err", fmut_err)?; + lua.load( + r#" + local _, err = fmut_err() + assert(err == "some error") + "#, + ) + .exec() + .unwrap(); Ok(()) } diff --git a/tests/types.rs b/tests/types.rs index 830851ae..27bbe04d 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -30,7 +30,7 @@ fn test_boolean_type_metatable() -> Result<()> { let lua = Lua::new(); let mt = lua.create_table()?; - mt.set("__add", Function::wrap(|_, (a, b): (bool, bool)| Ok(a || b)))?; + mt.set("__add", Function::wrap(|a, b| Ok(a || b)))?; lua.set_type_metatable::(Some(mt)); lua.load(r#"assert(true + true == true)"#).exec().unwrap(); @@ -48,7 +48,7 @@ fn test_lightuserdata_type_metatable() -> Result<()> { let mt = lua.create_table()?; mt.set( "__add", - Function::wrap(|_, (a, b): (LightUserData, LightUserData)| { + Function::wrap(|a: LightUserData, b: LightUserData| { Ok(LightUserData((a.0 as usize + b.0 as usize) as *mut c_void)) }), )?; @@ -76,7 +76,7 @@ fn test_number_type_metatable() -> Result<()> { let lua = Lua::new(); let mt = lua.create_table()?; - mt.set("__call", Function::wrap(|_, (n1, n2): (f64, f64)| Ok(n1 * n2)))?; + mt.set("__call", Function::wrap(|n1: f64, n2: f64| Ok(n1 * n2)))?; lua.set_type_metatable::(Some(mt)); lua.load(r#"assert((1.5)(3.0) == 4.5)"#).exec().unwrap(); lua.load(r#"assert((5)(5) == 25)"#).exec().unwrap(); @@ -91,7 +91,7 @@ fn test_string_type_metatable() -> Result<()> { let mt = lua.create_table()?; mt.set( "__add", - Function::wrap(|_, (a, b): (LuaString, LuaString)| Ok(format!("{}{}", a.to_str()?, b.to_str()?))), + Function::wrap(|a: String, b: String| Ok(format!("{a}{b}"))), )?; lua.set_type_metatable::(Some(mt)); @@ -107,7 +107,7 @@ fn test_function_type_metatable() -> Result<()> { let mt = lua.create_table()?; mt.set( "__index", - Function::wrap(|_, (_, key): (Function, String)| Ok(format!("function.{key}"))), + Function::wrap(|_: Function, key: String| Ok(format!("function.{key}"))), )?; lua.set_type_metatable::(Some(mt)); @@ -125,7 +125,7 @@ fn test_thread_type_metatable() -> Result<()> { let mt = lua.create_table()?; mt.set( "__index", - Function::wrap(|_, (_, key): (Thread, String)| Ok(format!("thread.{key}"))), + Function::wrap(|_: Thread, key: String| Ok(format!("thread.{key}"))), )?; lua.set_type_metatable::(Some(mt));