From c297d1e61bde36b840172a44b9fc6b97a37be6d6 Mon Sep 17 00:00:00 2001 From: Mohsen Zohrevandi Date: Wed, 21 Apr 2021 13:45:57 -0700 Subject: [PATCH] Ensure TLS destructors run before thread joins in SGX --- library/std/src/sys/sgx/abi/mod.rs | 6 ++- library/std/src/sys/sgx/thread.rs | 69 +++++++++++++++++++++++---- library/std/src/thread/local/tests.rs | 55 ++++++++++++++++++++- 3 files changed, 119 insertions(+), 11 deletions(-) diff --git a/library/std/src/sys/sgx/abi/mod.rs b/library/std/src/sys/sgx/abi/mod.rs index a5e453034762c..f9536c4203df2 100644 --- a/library/std/src/sys/sgx/abi/mod.rs +++ b/library/std/src/sys/sgx/abi/mod.rs @@ -62,10 +62,12 @@ unsafe extern "C" fn tcs_init(secondary: bool) { extern "C" fn entry(p1: u64, p2: u64, p3: u64, secondary: bool, p4: u64, p5: u64) -> EntryReturn { // FIXME: how to support TLS in library mode? let tls = Box::new(tls::Tls::new()); - let _tls_guard = unsafe { tls.activate() }; + let tls_guard = unsafe { tls.activate() }; if secondary { - super::thread::Thread::entry(); + let join_notifier = super::thread::Thread::entry(); + drop(tls_guard); + drop(join_notifier); EntryReturn(0, 0) } else { diff --git a/library/std/src/sys/sgx/thread.rs b/library/std/src/sys/sgx/thread.rs index 55ef460cc90c5..67e2e8b59d397 100644 --- a/library/std/src/sys/sgx/thread.rs +++ b/library/std/src/sys/sgx/thread.rs @@ -9,26 +9,37 @@ pub struct Thread(task_queue::JoinHandle); pub const DEFAULT_MIN_STACK_SIZE: usize = 4096; +pub use self::task_queue::JoinNotifier; + mod task_queue { - use crate::sync::mpsc; + use super::wait_notify; use crate::sync::{Mutex, MutexGuard, Once}; - pub type JoinHandle = mpsc::Receiver<()>; + pub type JoinHandle = wait_notify::Waiter; + + pub struct JoinNotifier(Option); + + impl Drop for JoinNotifier { + fn drop(&mut self) { + self.0.take().unwrap().notify(); + } + } pub(super) struct Task { p: Box, - done: mpsc::Sender<()>, + done: JoinNotifier, } impl Task { pub(super) fn new(p: Box) -> (Task, JoinHandle) { - let (done, recv) = mpsc::channel(); + let (done, recv) = wait_notify::new(); + let done = JoinNotifier(Some(done)); (Task { p, done }, recv) } - pub(super) fn run(self) { + pub(super) fn run(self) -> JoinNotifier { (self.p)(); - let _ = self.done.send(()); + self.done } } @@ -47,6 +58,48 @@ mod task_queue { } } +/// This module provides a synchronization primitive that does not use thread +/// local variables. This is needed for signaling that a thread has finished +/// execution. The signal is sent once all TLS destructors have finished at +/// which point no new thread locals should be created. +pub mod wait_notify { + use super::super::waitqueue::{SpinMutex, WaitQueue, WaitVariable}; + use crate::sync::Arc; + + pub struct Notifier(Arc>>); + + impl Notifier { + /// Notify the waiter. The waiter is either notified right away (if + /// currently blocked in `Waiter::wait()`) or later when it calls the + /// `Waiter::wait()` method. + pub fn notify(self) { + let mut guard = self.0.lock(); + *guard.lock_var_mut() = true; + let _ = WaitQueue::notify_one(guard); + } + } + + pub struct Waiter(Arc>>); + + impl Waiter { + /// Wait for a notification. If `Notifier::notify()` has already been + /// called, this will return immediately, otherwise the current thread + /// is blocked until notified. + pub fn wait(self) { + let guard = self.0.lock(); + if *guard.lock_var() { + return; + } + WaitQueue::wait(guard, || {}); + } + } + + pub fn new() -> (Notifier, Waiter) { + let inner = Arc::new(SpinMutex::new(WaitVariable::new(false))); + (Notifier(inner.clone()), Waiter(inner)) + } +} + impl Thread { // unsafe: see thread::Builder::spawn_unchecked for safety requirements pub unsafe fn new(_stack: usize, p: Box) -> io::Result { @@ -57,7 +110,7 @@ impl Thread { Ok(Thread(handle)) } - pub(super) fn entry() { + pub(super) fn entry() -> JoinNotifier { let mut pending_tasks = task_queue::lock(); let task = rtunwrap!(Some, pending_tasks.pop()); drop(pending_tasks); // make sure to not hold the task queue lock longer than necessary @@ -78,7 +131,7 @@ impl Thread { } pub fn join(self) { - let _ = self.0.recv(); + self.0.wait(); } } diff --git a/library/std/src/thread/local/tests.rs b/library/std/src/thread/local/tests.rs index 80e6798d847b1..98f525eafb0f6 100644 --- a/library/std/src/thread/local/tests.rs +++ b/library/std/src/thread/local/tests.rs @@ -1,5 +1,6 @@ use crate::cell::{Cell, UnsafeCell}; -use crate::sync::mpsc::{channel, Sender}; +use crate::sync::atomic::{AtomicBool, Ordering}; +use crate::sync::mpsc::{self, channel, Sender}; use crate::thread::{self, LocalKey}; use crate::thread_local; @@ -207,3 +208,55 @@ fn dtors_in_dtors_in_dtors_const_init() { }); rx.recv().unwrap(); } + +// This test tests that TLS destructors have run before the thread joins. The +// test has no false positives (meaning: if the test fails, there's actually +// an ordering problem). It may have false negatives, where the test passes but +// join is not guaranteed to be after the TLS destructors. However, false +// negatives should be exceedingly rare due to judicious use of +// thread::yield_now and running the test several times. +#[test] +fn join_orders_after_tls_destructors() { + static THREAD2_LAUNCHED: AtomicBool = AtomicBool::new(false); + + for _ in 0..10 { + let (tx, rx) = mpsc::sync_channel(0); + THREAD2_LAUNCHED.store(false, Ordering::SeqCst); + + let jh = thread::spawn(move || { + struct RecvOnDrop(Cell>>); + + impl Drop for RecvOnDrop { + fn drop(&mut self) { + let rx = self.0.take().unwrap(); + while !THREAD2_LAUNCHED.load(Ordering::SeqCst) { + thread::yield_now(); + } + rx.recv().unwrap(); + } + } + + thread_local! { + static TL_RX: RecvOnDrop = RecvOnDrop(Cell::new(None)); + } + + TL_RX.with(|v| v.0.set(Some(rx))) + }); + + let tx_clone = tx.clone(); + let jh2 = thread::spawn(move || { + THREAD2_LAUNCHED.store(true, Ordering::SeqCst); + jh.join().unwrap(); + tx_clone.send(()).expect_err( + "Expecting channel to be closed because thread 1 TLS destructors must've run", + ); + }); + + while !THREAD2_LAUNCHED.load(Ordering::SeqCst) { + thread::yield_now(); + } + thread::yield_now(); + tx.send(()).expect("Expecting channel to be live because thread 2 must block on join"); + jh2.join().unwrap(); + } +}