Switch to optimized lock #58

Open · wants to merge 3 commits into master

6 changes: 4 additions & 2 deletions Cargo.toml
@@ -13,7 +13,7 @@ rust-version = "1.59"

[features]
# this feature provides performance improvements using nightly features
-nightly = []
+nightly = ["memoffset"]

[badges]
travis-ci = { repository = "Amanieu/thread_local-rs" }
@@ -22,9 +22,11 @@ travis-ci = { repository = "Amanieu/thread_local-rs" }
once_cell = "1.5.2"
# this is required to gate `nightly` related code paths
cfg-if = "1.0.0"
crossbeam-utils = "0.8.15"
memoffset = { version = "0.9.0", optional = true }

[dev-dependencies]
criterion = "0.4.0"
criterion = "0.4"

[[bench]]
name = "thread_local"
31 changes: 19 additions & 12 deletions src/lib.rs
@@ -68,6 +68,7 @@
#![cfg_attr(feature = "nightly", feature(thread_local))]

mod cached;
mod mutex;
mod thread_id;
mod unreachable;

@@ -187,10 +188,11 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> T,
{
-unsafe {
-    self.get_or_try(|| Ok::<T, ()>(create()))
-        .unchecked_unwrap_ok()
-}
+if let Some(val) = self.get() {
+    return val;
+}
+
+self.insert(create)
}

/// Returns the element for the current thread, or creates it if it doesn't
@@ -200,12 +202,11 @@ impl<T: Send> ThreadLocal<T> {
where
F: FnOnce() -> Result<T, E>,
{
-let thread = thread_id::get();
-if let Some(val) = self.get_inner(thread) {
+if let Some(val) = self.get() {
return Ok(val);
}

-Ok(self.insert(create()?))
+self.insert_maybe(create)
}

fn get_inner(&self, thread: Thread) -> Option<&T> {
@@ -226,14 +227,22 @@ impl<T: Send> ThreadLocal<T> {
}

#[cold]
-fn insert(&self, data: T) -> &T {
+fn insert_maybe<F: FnOnce() -> Result<T, E>, E>(&self, gen: F) -> Result<&T, E> {
let data = gen()?;
Ok(self.insert(|| data))
}

#[cold]
fn insert<F: FnOnce() -> T>(&self, gen: F) -> &T {
// call the generator here, so it is #[cold] as well.
let data = gen();
let thread = thread_id::get();
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);

// If the bucket doesn't already exist, we need to allocate it
let bucket_ptr = if bucket_ptr.is_null() {
-let new_bucket = allocate_bucket(thread.bucket_size);
+let new_bucket = allocate_bucket(thread.bucket_size());

match bucket_atomic_ptr.compare_exchange(
ptr::null_mut(),
@@ -246,7 +255,7 @@ impl<T: Send> ThreadLocal<T> {
// another thread stored a new bucket before we could,
// and we can free our bucket and use that one instead
Err(bucket_ptr) => {
-unsafe { deallocate_bucket(new_bucket, thread.bucket_size) }
+unsafe { deallocate_bucket(new_bucket, thread.bucket_size()) }
bucket_ptr
}
}
@@ -495,9 +504,7 @@ impl<T: Send> Iterator for IntoIter<T> {
fn next(&mut self) -> Option<T> {
self.raw.next_mut(&mut self.thread_local).map(|entry| {
*entry.present.get_mut() = false;
-unsafe {
-    std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
-}
+unsafe { mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init() }
})
}
fn size_hint(&self) -> (usize, Option<usize>) {
80 changes: 80 additions & 0 deletions src/mutex.rs
@@ -0,0 +1,80 @@
use crossbeam_utils::Backoff;
use std::cell::UnsafeCell;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};

/// A mutex optimized for low contention.
pub(crate) struct Mutex<T> {
guard: AtomicBool,
data: UnsafeCell<T>,
}

impl<T> Mutex<T> {
#[inline]
pub const fn new(val: T) -> Self {
Self {
guard: AtomicBool::new(false),
data: UnsafeCell::new(val),
}
}

pub fn lock(&self) -> MutexGuard<'_, T> {
if self
.guard
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
{
return MutexGuard(self);
}

let backoff = Backoff::new();
while self
.guard
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
backoff.snooze();
}
MutexGuard(self)
}

#[inline]
pub fn try_lock(&self) -> Option<MutexGuard<'_, T>> {
if self
.guard
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
return None;
}
Some(MutexGuard(self))
}
}

unsafe impl<T: Send> Send for Mutex<T> {}
// `Send` (not `Sync`) is the right bound here: the lock hands out `&mut T`
// to whichever thread acquires it, mirroring `std::sync::Mutex`.
unsafe impl<T: Send> Sync for Mutex<T> {}

pub(crate) struct MutexGuard<'a, T>(&'a Mutex<T>);

impl<'a, T> Deref for MutexGuard<'a, T> {
type Target = T;

#[inline(always)]
fn deref(&self) -> &Self::Target {
unsafe { &*self.0.data.get() }
}
}

impl<'a, T> DerefMut for MutexGuard<'a, T> {
#[inline(always)]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.0.data.get() }
}
}

impl<'a, T> Drop for MutexGuard<'a, T> {
#[inline]
fn drop(&mut self) {
self.0.guard.store(false, Ordering::Release);
}
}
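
A minimal usage sketch of this lock (hypothetical in-crate code, not part of the diff; the crate itself only ever wraps its `ThreadIdManager` in it):

```rust
use crate::mutex::Mutex;

fn example() {
    let counter = Mutex::new(0usize);

    // `lock` spins (via crossbeam's `Backoff`) until the flag is acquired.
    *counter.lock() += 1;

    // `try_lock` returns `None` instead of waiting when contended.
    if let Some(mut guard) = counter.try_lock() {
        *guard += 1;
    }

    // Dropping the guard releases the lock with a `Release` store.
    assert_eq!(*counter.lock(), 2);
}
```
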
96 changes: 63 additions & 33 deletions src/thread_id.rs
@@ -5,12 +5,12 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

+use crate::mutex::Mutex;
use crate::POINTER_WIDTH;
use once_cell::sync::Lazy;
use std::cell::Cell;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
-use std::sync::Mutex;

/// Thread ID manager which allocates thread IDs. It attempts to aggressively
/// reuse thread IDs where possible to avoid cases where a ThreadLocal grows
@@ -49,38 +49,47 @@ static THREAD_ID_MANAGER: Lazy<Mutex<ThreadIdManager>> =
/// A thread ID may be reused after a thread exits.
#[derive(Clone, Copy)]
pub(crate) struct Thread {
-/// The thread ID obtained from the thread ID manager.
-pub(crate) id: usize,
/// The bucket this thread's local storage will be in.
pub(crate) bucket: usize,
-/// The size of the bucket this thread's local storage will be in.
-pub(crate) bucket_size: usize,
/// The index into the bucket this thread's local storage is in.
pub(crate) index: usize,
}

impl Thread {
/// id: The thread ID obtained from the thread ID manager.
#[inline]
fn new(id: usize) -> Self {
let bucket = usize::from(POINTER_WIDTH) - ((id + 1).leading_zeros() as usize) - 1;
let bucket_size = 1 << bucket;
let index = id - (bucket_size - 1);
Self { bucket, index }
}

-Self {
-    id,
-    bucket,
-    bucket_size,
-    index,
-}
/// The size of the bucket this thread's local storage will be in.
#[inline]
pub fn bucket_size(&self) -> usize {
1 << self.bucket
}
}
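
As a worked example of the bucket math (matching the tests at the bottom of this file): for id = 19, id + 1 is 0b10100, so `(id + 1).leading_zeros()` is 59 on a 64-bit target, giving bucket = 64 - 59 - 1 = 4, bucket_size() = 1 << 4 = 16, and index = 19 - (16 - 1) = 4.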

cfg_if::cfg_if! {
if #[cfg(feature = "nightly")] {
use memoffset::offset_of;
use std::ptr::null;
use std::cell::UnsafeCell;

// This is split into 2 thread-local variables so that we can check whether the
// thread is initialized without having to register a thread-local destructor.
//
// This makes the fast path smaller.
#[thread_local]
-static mut THREAD: Option<Thread> = None;
+static THREAD: UnsafeCell<ThreadWrapper> = UnsafeCell::new(ThreadWrapper {
+    self_ptr: null(),
+    thread: Thread {
+        index: 0,
+        bucket: 0,
+    },
+});
thread_local! { static THREAD_GUARD: ThreadGuard = const { ThreadGuard { id: Cell::new(0) } }; }

// Guard to ensure the thread ID is released on thread exit.
@@ -97,17 +106,41 @@ cfg_if::cfg_if! {
// will go through get_slow which will either panic or
// initialize a new ThreadGuard.
unsafe {
-THREAD = None;
+(&mut *THREAD.get()).self_ptr = null();
}
-THREAD_ID_MANAGER.lock().unwrap().free(self.id.get());
+THREAD_ID_MANAGER.lock().free(self.id.get());
}
}

/// Data which is unique to the current thread while it is running.
/// A thread ID may be reused after a thread exits.
///
/// This wrapper exists to hide multiple accesses to the TLS data
/// from the compiler backend, as those can lead to inefficient
/// codegen (to be precise, multiple TLS address lookups).
#[derive(Clone, Copy)]
struct ThreadWrapper {
self_ptr: *const Thread,
Owner

What is the purpose of self_ptr? It seems to always point to the thread field. Is this just a bool that indicates whether the thread has been initialized? How is this better than just using an Option?

Contributor Author @terrarier2111 (Jan 18, 2024)

That is what I was describing earlier: this is the hack that tricks the compiler into only ever generating one get_tls_addr call and reusing the calculated address. If the compiler were smart enough, this self_ptr wouldn't be necessary at all.
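
A standalone sketch of that trick (hypothetical names `Data`/`Wrapper`/`TLS`, nightly-only, independent of this crate):

```rust
#![feature(thread_local)]

use std::cell::UnsafeCell;
use std::ptr;

#[derive(Clone, Copy)]
struct Data {
    id: usize,
}

#[derive(Clone, Copy)]
struct Wrapper {
    // Null until initialized; afterwards always points at `data` below.
    self_ptr: *const Data,
    data: Data,
}

#[thread_local]
static TLS: UnsafeCell<Wrapper> = UnsafeCell::new(Wrapper {
    self_ptr: ptr::null(),
    data: Data { id: 0 },
});

#[inline]
fn get() -> Data {
    // Copy the wrapper out with a single TLS address computation.
    let wrapper = unsafe { *TLS.get() };
    if !wrapper.self_ptr.is_null() {
        // Read through the embedded pointer instead of naming the
        // `#[thread_local]` static a second time, so the backend reuses
        // the address it already computed rather than emitting another
        // TLS address lookup.
        unsafe { wrapper.self_ptr.read() }
    } else {
        init_slow()
    }
}

#[cold]
fn init_slow() -> Data {
    unsafe {
        let w = &mut *TLS.get();
        w.data = Data { id: 42 };
        w.self_ptr = &w.data;
        w.data
    }
}

fn main() {
    assert_eq!(get().id, 42);
}
```
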

thread: Thread,
}

impl ThreadWrapper {
/// The thread ID obtained from the thread ID manager.
#[inline]
fn new(id: usize) -> Self {
Self {
    self_ptr: ((THREAD.get().cast_const() as usize) + offset_of!(ThreadWrapper, thread)) as *const Thread,
    thread: Thread::new(id),
}
}
}

/// Returns a thread ID for the current thread, allocating one if needed.
#[inline]
pub(crate) fn get() -> Thread {
-if let Some(thread) = unsafe { THREAD } {
-    thread
+let thread = unsafe { *THREAD.get() };
+if !thread.self_ptr.is_null() {
+    unsafe { thread.self_ptr.read() }
} else {
get_slow()
}
@@ -116,12 +149,13 @@ cfg_if::cfg_if! {
/// Out-of-line slow path for allocating a thread ID.
#[cold]
fn get_slow() -> Thread {
-let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
+let id = THREAD_ID_MANAGER.lock().alloc();
+let new = ThreadWrapper::new(id);
unsafe {
-THREAD = Some(new);
+*THREAD.get() = new;
}
-THREAD_GUARD.with(|guard| guard.id.set(new.id));
-new
+THREAD_GUARD.with(|guard| guard.id.set(id));
+new.thread
}
} else {
// This is split into 2 thread-local variables so that we can check whether the
@@ -145,7 +179,7 @@ cfg_if::cfg_if! {
// will go through get_slow which will either panic or
// initialize a new ThreadGuard.
let _ = THREAD.try_with(|thread| thread.set(None));
-THREAD_ID_MANAGER.lock().unwrap().free(self.id.get());
+THREAD_ID_MANAGER.lock().free(self.id.get());
}
}

@@ -164,9 +198,10 @@ cfg_if::cfg_if! {
/// Out-of-line slow path for allocating a thread ID.
#[cold]
fn get_slow(thread: &Cell<Option<Thread>>) -> Thread {
-let new = Thread::new(THREAD_ID_MANAGER.lock().unwrap().alloc());
+let id = THREAD_ID_MANAGER.lock().alloc();
+let new = Thread::new(id);
thread.set(Some(new));
-THREAD_GUARD.with(|guard| guard.id.set(new.id));
+THREAD_GUARD.with(|guard| guard.id.set(id));
new
}
}
@@ -175,32 +210,27 @@
#[test]
fn test_thread() {
let thread = Thread::new(0);
-assert_eq!(thread.id, 0);
assert_eq!(thread.bucket, 0);
-assert_eq!(thread.bucket_size, 1);
+assert_eq!(thread.bucket_size(), 1);
assert_eq!(thread.index, 0);

let thread = Thread::new(1);
-assert_eq!(thread.id, 1);
assert_eq!(thread.bucket, 1);
-assert_eq!(thread.bucket_size, 2);
+assert_eq!(thread.bucket_size(), 2);
assert_eq!(thread.index, 0);

let thread = Thread::new(2);
-assert_eq!(thread.id, 2);
assert_eq!(thread.bucket, 1);
-assert_eq!(thread.bucket_size, 2);
+assert_eq!(thread.bucket_size(), 2);
assert_eq!(thread.index, 1);

let thread = Thread::new(3);
-assert_eq!(thread.id, 3);
assert_eq!(thread.bucket, 2);
-assert_eq!(thread.bucket_size, 4);
+assert_eq!(thread.bucket_size(), 4);
assert_eq!(thread.index, 0);

let thread = Thread::new(19);
-assert_eq!(thread.id, 19);
assert_eq!(thread.bucket, 4);
-assert_eq!(thread.bucket_size, 16);
+assert_eq!(thread.bucket_size(), 16);
assert_eq!(thread.index, 4);
}