Skip to content

Commit

Permalink
Switch to Mutex from RwLock for userdata access in send mode.
Browse files Browse the repository at this point in the history
Unfortunately RwLock allow access to the userdata from multiple threads
without enforcing `Sync` marker.
  • Loading branch information
khvzak committed Aug 24, 2024
1 parent 2857cb7 commit 23d4e25
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
14 changes: 4 additions & 10 deletions src/userdata/cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type DynSerialize = dyn erased_serde::Serialize + Send;
pub(crate) enum UserDataVariant<T> {
Default(XRc<UserDataCell<T>>),
#[cfg(feature = "serialize")]
Serializable(XRc<UserDataCell<ForceSync<Box<DynSerialize>>>>),
Serializable(XRc<UserDataCell<Box<DynSerialize>>>),
}

impl<T> Clone for UserDataVariant<T> {
Expand Down Expand Up @@ -82,7 +82,7 @@ impl<T> UserDataVariant<T> {
Self::Default(inner) => XRc::into_inner(inner).unwrap().value.into_inner(),
#[cfg(feature = "serialize")]
Self::Serializable(inner) => unsafe {
let raw = Box::into_raw(XRc::into_inner(inner).unwrap().value.into_inner().0);
let raw = Box::into_raw(XRc::into_inner(inner).unwrap().value.into_inner());
*Box::from_raw(raw as *mut T)
},
})
Expand Down Expand Up @@ -112,7 +112,6 @@ impl<T: Serialize + MaybeSend + 'static> UserDataVariant<T> {
#[inline(always)]
pub(crate) fn new_ser(data: T) -> Self {
let data = Box::new(data) as Box<DynSerialize>;
let data = ForceSync(data);
Self::Serializable(XRc::new(UserDataCell::new(data)))
}
}
Expand All @@ -129,7 +128,7 @@ impl Serialize for UserDataVariant<()> {
// No need to do this if the `send` feature is disabled.
#[cfg(not(feature = "send"))]
let _guard = self.try_borrow().map_err(serde::ser::Error::custom)?;
(*inner.value.get()).0.serialize(serializer)
(*inner.value.get()).serialize(serializer)
},
}
}
Expand All @@ -142,7 +141,7 @@ pub(crate) struct UserDataCell<T> {
}

unsafe impl<T: Send> Send for UserDataCell<T> {}
unsafe impl<T: Send + Sync> Sync for UserDataCell<T> {}
unsafe impl<T: Send> Sync for UserDataCell<T> {}

impl<T> UserDataCell<T> {
#[inline(always)]
Expand Down Expand Up @@ -352,11 +351,6 @@ impl<'a, T> TryFrom<&'a UserDataVariant<T>> for UserDataBorrowMut<'a, T> {
}
}

#[repr(transparent)]
pub(crate) struct ForceSync<T>(T);

unsafe impl<T: Send> Sync for ForceSync<T> {}

#[inline]
fn try_value_to_userdata<T>(value: Value) -> Result<AnyUserData> {
match value {
Expand Down
14 changes: 7 additions & 7 deletions src/userdata/lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,32 +62,32 @@ mod lock_impl {

#[cfg(feature = "send")]
mod lock_impl {
use parking_lot::lock_api::RawRwLock;
use parking_lot::lock_api::RawMutex;

pub(crate) type RawLock = parking_lot::RawRwLock;
pub(crate) type RawLock = parking_lot::RawMutex;

impl super::UserDataLock for RawLock {
#[allow(clippy::declare_interior_mutable_const)]
const INIT: Self = <Self as parking_lot::lock_api::RawRwLock>::INIT;
const INIT: Self = <Self as parking_lot::lock_api::RawMutex>::INIT;

#[inline(always)]
fn try_lock_shared(&self) -> bool {
RawRwLock::try_lock_shared(self)
RawLock::try_lock(self)
}

#[inline(always)]
fn try_lock_exclusive(&self) -> bool {
RawRwLock::try_lock_exclusive(self)
RawLock::try_lock(self)
}

#[inline(always)]
unsafe fn unlock_shared(&self) {
RawRwLock::unlock_shared(self)
RawLock::unlock(self)
}

#[inline(always)]
unsafe fn unlock_exclusive(&self) {
RawRwLock::unlock_exclusive(self)
RawLock::unlock(self)
}
}
}
35 changes: 35 additions & 0 deletions tests/send.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#![cfg(feature = "send")]

use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::string::String as StdString;

use mlua::{AnyUserData, Error, Lua, Result, UserDataRef};
use static_assertions::{assert_impl_all, assert_not_impl_all};

#[test]
fn test_userdata_multithread_access() -> Result<()> {
let lua = Lua::new();

// This type is `Send` but not `Sync`.
struct MyUserData(#[allow(unused)] StdString, PhantomData<UnsafeCell<()>>);

assert_impl_all!(MyUserData: Send);
assert_not_impl_all!(MyUserData: Sync);

lua.globals().set(
"ud",
AnyUserData::wrap(MyUserData("hello".to_string(), PhantomData)),
)?;
// We acquired the exclusive reference.
let _ud1 = lua.globals().get::<UserDataRef<MyUserData>>("ud")?;

std::thread::scope(|s| {
s.spawn(|| {
let res = lua.globals().get::<UserDataRef<MyUserData>>("ud");
assert!(matches!(res, Err(Error::UserDataBorrowError)));
});
});

Ok(())
}

0 comments on commit 23d4e25

Please sign in to comment.