From d69e8848d1e703f41a032e98e6a87d5900cd7d71 Mon Sep 17 00:00:00 2001 From: mpage Date: Fri, 1 Mar 2024 13:43:12 -0800 Subject: [PATCH] gh-114271: Make `_thread.ThreadHandle` thread-safe in free-threaded builds (GH-115190) Make `_thread.ThreadHandle` thread-safe in free-threaded builds We protect the mutable state of `ThreadHandle` using a `_PyOnceFlag`. Concurrent operations (i.e. `join` or `detach`) on `ThreadHandle` block until it is their turn to execute or an earlier operation succeeds. Once an operation has been applied successfully all future operations complete immediately. The `join()` method is now idempotent. It may be called multiple times but the underlying OS thread will only be joined once. After `join()` succeeds, any future calls to `join()` will succeed immediately. The internal thread handle `detach()` method has been removed. --- Include/internal/pycore_lock.h | 13 +++ Lib/test/test_thread.py | 91 +++++++++--------- Lib/threading.py | 24 ++--- Modules/_threadmodule.c | 167 +++++++++++++++++++++++++-------- Python/lock.c | 31 ++++++ 5 files changed, 225 insertions(+), 101 deletions(-) diff --git a/Include/internal/pycore_lock.h b/Include/internal/pycore_lock.h index c89159b55e130f6..f648be496ea4af3 100644 --- a/Include/internal/pycore_lock.h +++ b/Include/internal/pycore_lock.h @@ -136,6 +136,10 @@ typedef struct { uint8_t v; } PyEvent; +// Check if the event is set without blocking. Returns 1 if the event is set or +// 0 otherwise. +PyAPI_FUNC(int) _PyEvent_IsSet(PyEvent *evt); + // Set the event and notify any waiting threads. // Export for '_testinternalcapi' shared extension PyAPI_FUNC(void) _PyEvent_Notify(PyEvent *evt); @@ -149,6 +153,15 @@ PyAPI_FUNC(void) PyEvent_Wait(PyEvent *evt); // and 0 if the timeout expired or thread was interrupted. PyAPI_FUNC(int) PyEvent_WaitTimed(PyEvent *evt, PyTime_t timeout_ns); +// A one-time event notification with reference counting. 
+typedef struct _PyEventRc { + PyEvent event; + Py_ssize_t refcount; +} _PyEventRc; + +_PyEventRc *_PyEventRc_New(void); +void _PyEventRc_Incref(_PyEventRc *erc); +void _PyEventRc_Decref(_PyEventRc *erc); // _PyRawMutex implements a word-sized mutex that that does not depend on the // parking lot API, and therefore can be used in the parking lot diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py index 931cb4b797e0b21..83235230d5c1120 100644 --- a/Lib/test/test_thread.py +++ b/Lib/test/test_thread.py @@ -189,8 +189,8 @@ def task(): with threading_helper.wait_threads_exit(): handle = thread.start_joinable_thread(task) handle.join() - with self.assertRaisesRegex(ValueError, "not joinable"): - handle.join() + # Subsequent join() calls should succeed + handle.join() def test_joinable_not_joined(self): handle_destroyed = thread.allocate_lock() @@ -233,58 +233,61 @@ def task(): with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"): raise errors[0] - def test_detach_from_self(self): - errors = [] - handles = [] - start_joinable_thread_returned = thread.allocate_lock() - start_joinable_thread_returned.acquire() - thread_detached = thread.allocate_lock() - thread_detached.acquire() + def test_join_then_self_join(self): + # make sure we can't deadlock in the following scenario with + # threads t0 and t1 (see comment in `ThreadHandle_join()` for more + # details): + # + # - t0 joins t1 + # - t1 self joins + def make_lock(): + lock = thread.allocate_lock() + lock.acquire() + return lock + + error = None + self_joiner_handle = None + self_joiner_started = make_lock() + self_joiner_barrier = make_lock() + def self_joiner(): + nonlocal error + + self_joiner_started.release() + self_joiner_barrier.acquire() - def task(): - start_joinable_thread_returned.acquire() try: - handles[0].detach() + self_joiner_handle.join() except Exception as e: - errors.append(e) - finally: - thread_detached.release() + error = e + + joiner_started = make_lock() + def 
joiner(): + joiner_started.release() + self_joiner_handle.join() with threading_helper.wait_threads_exit(): - handle = thread.start_joinable_thread(task) - handles.append(handle) - start_joinable_thread_returned.release() - thread_detached.acquire() - with self.assertRaisesRegex(ValueError, "not joinable"): - handle.join() + self_joiner_handle = thread.start_joinable_thread(self_joiner) + # Wait for the self-joining thread to start + self_joiner_started.acquire() - assert len(errors) == 0 + # Start the thread that joins the self-joiner + joiner_handle = thread.start_joinable_thread(joiner) - def test_detach_then_join(self): - lock = thread.allocate_lock() - lock.acquire() + # Wait for the joiner to start + joiner_started.acquire() - def task(): - lock.acquire() + # Not great, but I don't think there's a deterministic way to make + # sure that the self-joining thread has been joined. + time.sleep(0.1) - with threading_helper.wait_threads_exit(): - handle = thread.start_joinable_thread(task) - # detach() returns even though the thread is blocked on lock - handle.detach() - # join() then cannot be called anymore - with self.assertRaisesRegex(ValueError, "not joinable"): - handle.join() - lock.release() - - def test_join_then_detach(self): - def task(): - pass + # Unblock the self-joiner + self_joiner_barrier.release() - with threading_helper.wait_threads_exit(): - handle = thread.start_joinable_thread(task) - handle.join() - with self.assertRaisesRegex(ValueError, "not joinable"): - handle.detach() + self_joiner_handle.join() + joiner_handle.join() + + with self.assertRaisesRegex(RuntimeError, "Cannot join current thread"): + raise error class Barrier: diff --git a/Lib/threading.py b/Lib/threading.py index b6ff00acadd58fe..ec89550d6b022ee 100644 --- a/Lib/threading.py +++ b/Lib/threading.py @@ -931,7 +931,6 @@ class is implemented. 
if _HAVE_THREAD_NATIVE_ID: self._native_id = None self._tstate_lock = None - self._join_lock = None self._handle = None self._started = Event() self._is_stopped = False @@ -956,14 +955,11 @@ def _after_fork(self, new_ident=None): if self._tstate_lock is not None: self._tstate_lock._at_fork_reinit() self._tstate_lock.acquire() - if self._join_lock is not None: - self._join_lock._at_fork_reinit() else: # This thread isn't alive after fork: it doesn't have a tstate # anymore. self._is_stopped = True self._tstate_lock = None - self._join_lock = None self._handle = None def __repr__(self): @@ -996,8 +992,6 @@ def start(self): if self._started.is_set(): raise RuntimeError("threads can only be started once") - self._join_lock = _allocate_lock() - with _active_limbo_lock: _limbo[self] = self try: @@ -1167,17 +1161,9 @@ def join(self, timeout=None): self._join_os_thread() def _join_os_thread(self): - join_lock = self._join_lock - if join_lock is None: - return - with join_lock: - # Calling join() multiple times would raise an exception - # in one of the callers. - if self._handle is not None: - self._handle.join() - self._handle = None - # No need to keep this around - self._join_lock = None + # self._handle may be cleared post-fork + if self._handle is not None: + self._handle.join() def _wait_for_tstate_lock(self, block=True, timeout=-1): # Issue #18808: wait for the thread state to be gone. @@ -1478,6 +1464,10 @@ def __init__(self): with _active_limbo_lock: _active[self._ident] = self + def _join_os_thread(self): + # No ThreadHandle for main thread + pass + # Helper thread-local instance to detect when a _DummyThread # is collected. Not a part of the public API. 
diff --git a/Modules/_threadmodule.c b/Modules/_threadmodule.c index 4c2185cc7ea1fd3..3a8f77d6dfbbc6a 100644 --- a/Modules/_threadmodule.c +++ b/Modules/_threadmodule.c @@ -1,9 +1,9 @@ - /* Thread module */ /* Interface to Sjoerd's portable C thread library */ #include "Python.h" #include "pycore_interp.h" // _PyInterpreterState.threads.count +#include "pycore_lock.h" #include "pycore_moduleobject.h" // _PyModule_GetState() #include "pycore_modsupport.h" // _PyArg_NoKeywords() #include "pycore_pylifecycle.h" @@ -44,24 +44,76 @@ get_thread_state(PyObject *module) // _ThreadHandle type +// Handles transition from RUNNING to one of JOINED, DETACHED, or INVALID (post +// fork). +typedef enum { + THREAD_HANDLE_RUNNING = 1, + THREAD_HANDLE_JOINED = 2, + THREAD_HANDLE_DETACHED = 3, + THREAD_HANDLE_INVALID = 4, +} ThreadHandleState; + +// A handle around an OS thread. +// +// The OS thread is either joined or detached after the handle is destroyed. +// +// Joining the handle is idempotent; the underlying OS thread is joined or +// detached only once. Concurrent join operations are serialized until it is +// their turn to execute or an earlier operation completes successfully. Once a +// join has completed successfully all future joins complete immediately. typedef struct { PyObject_HEAD struct llist_node node; // linked list node (see _pythread_runtime_state) + + // The `ident` and `handle` fields are immutable once the object is visible + // to threads other than its creator, thus they do not need to be accessed + // atomically. PyThread_ident_t ident; PyThread_handle_t handle; - char joinable; + + // Holds a value from the `ThreadHandleState` enum. + int state; + + // Set immediately before `thread_run` returns to indicate that the OS + // thread is about to exit. This is used to avoid false positives when + // detecting self-join attempts. See the comment in `ThreadHandle_join()` + // for a more detailed explanation. 
+ _PyEventRc *thread_is_exiting; + + // Serializes calls to `join`. + _PyOnceFlag once; } ThreadHandleObject; +static inline int +get_thread_handle_state(ThreadHandleObject *handle) +{ + return _Py_atomic_load_int(&handle->state); +} + +static inline void +set_thread_handle_state(ThreadHandleObject *handle, ThreadHandleState state) +{ + _Py_atomic_store_int(&handle->state, state); +} + static ThreadHandleObject* new_thread_handle(thread_module_state* state) { + _PyEventRc *event = _PyEventRc_New(); + if (event == NULL) { + PyErr_NoMemory(); + return NULL; + } ThreadHandleObject* self = PyObject_New(ThreadHandleObject, state->thread_handle_type); if (self == NULL) { + _PyEventRc_Decref(event); return NULL; } self->ident = 0; self->handle = 0; - self->joinable = 0; + self->thread_is_exiting = event; + self->once = (_PyOnceFlag){0}; + self->state = THREAD_HANDLE_INVALID; HEAD_LOCK(&_PyRuntime); llist_insert_tail(&_PyRuntime.threads.handles, &self->node); @@ -82,13 +134,21 @@ ThreadHandle_dealloc(ThreadHandleObject *self) } HEAD_UNLOCK(&_PyRuntime); - if (self->joinable) { - int ret = PyThread_detach_thread(self->handle); - if (ret) { + // It's safe to access state non-atomically: + // 1. This is the destructor; nothing else holds a reference. + // 2. The refcount going to zero is a "synchronizes-with" event; + // all changes from other threads are visible. + if (self->state == THREAD_HANDLE_RUNNING) { + // This is typically short so no need to release the GIL + if (PyThread_detach_thread(self->handle)) { PyErr_SetString(ThreadError, "Failed detaching thread"); PyErr_WriteUnraisable(tp); } + else { + self->state = THREAD_HANDLE_DETACHED; + } } + _PyEventRc_Decref(self->thread_is_exiting); PyObject_Free(self); Py_DECREF(tp); } @@ -109,8 +169,9 @@ _PyThread_AfterFork(struct _pythread_runtime_state *state) continue; } - // Disallow calls to detach() and join() as they could crash. - hobj->joinable = 0; + // Disallow calls to join() as they could crash. 
We are the only + // thread; it's safe to set this without an atomic. + hobj->state = THREAD_HANDLE_INVALID; llist_remove(node); } } @@ -128,48 +189,54 @@ ThreadHandle_get_ident(ThreadHandleObject *self, void *ignored) return PyLong_FromUnsignedLongLong(self->ident); } - -static PyObject * -ThreadHandle_detach(ThreadHandleObject *self, void* ignored) +static int +join_thread(ThreadHandleObject *handle) { - if (!self->joinable) { - PyErr_SetString(PyExc_ValueError, - "the thread is not joinable and thus cannot be detached"); - return NULL; - } - self->joinable = 0; - // This is typically short so no need to release the GIL - int ret = PyThread_detach_thread(self->handle); - if (ret) { - PyErr_SetString(ThreadError, "Failed detaching thread"); - return NULL; + assert(get_thread_handle_state(handle) == THREAD_HANDLE_RUNNING); + + int err; + Py_BEGIN_ALLOW_THREADS + err = PyThread_join_thread(handle->handle); + Py_END_ALLOW_THREADS + if (err) { + PyErr_SetString(ThreadError, "Failed joining thread"); + return -1; } - Py_RETURN_NONE; + set_thread_handle_state(handle, THREAD_HANDLE_JOINED); + return 0; } static PyObject * ThreadHandle_join(ThreadHandleObject *self, void* ignored) { - if (!self->joinable) { - PyErr_SetString(PyExc_ValueError, "the thread is not joinable"); + if (get_thread_handle_state(self) == THREAD_HANDLE_INVALID) { + PyErr_SetString(PyExc_ValueError, + "the handle is invalid and thus cannot be joined"); return NULL; } - if (self->ident == PyThread_get_thread_ident_ex()) { + + // We want to perform this check outside of the `_PyOnceFlag` to prevent + // deadlock in the scenario where another thread joins us and we then + // attempt to join ourselves. However, it's not safe to check thread + // identity once the handle's OS thread has finished: the OS may reuse the + // identity stored in the handle for a new thread, which would then be + // mistaken for a thread attempting to join itself.
+ // + // To work around this, we set `thread_is_exiting` immediately before + // `thread_run` returns. We can be sure that we are not attempting to join + // ourselves if the handle's thread is about to exit. + if (!_PyEvent_IsSet(&self->thread_is_exiting->event) && + self->ident == PyThread_get_thread_ident_ex()) { // PyThread_join_thread() would deadlock or error out. PyErr_SetString(ThreadError, "Cannot join current thread"); return NULL; } - // Before actually joining, we must first mark the thread as non-joinable, - // as joining several times simultaneously or sequentially is undefined behavior. - self->joinable = 0; - int ret; - Py_BEGIN_ALLOW_THREADS - ret = PyThread_join_thread(self->handle); - Py_END_ALLOW_THREADS - if (ret) { - PyErr_SetString(ThreadError, "Failed joining thread"); + + if (_PyOnceFlag_CallOnce(&self->once, (_Py_once_fn_t *)join_thread, + self) == -1) { return NULL; } + assert(get_thread_handle_state(self) == THREAD_HANDLE_JOINED); Py_RETURN_NONE; } @@ -180,7 +247,6 @@ static PyGetSetDef ThreadHandle_getsetlist[] = { static PyMethodDef ThreadHandle_methods[] = { - {"detach", (PyCFunction)ThreadHandle_detach, METH_NOARGS}, {"join", (PyCFunction)ThreadHandle_join, METH_NOARGS}, {0, 0} }; @@ -1210,11 +1276,15 @@ _localdummy_destroyed(PyObject *localweakref, PyObject *dummyweakref) /* Module functions */ +// bootstate is used to "bootstrap" new threads. Any arguments needed by +// `thread_run()`, which can only take a single argument due to platform +// limitations, are contained in bootstate. 
struct bootstate { PyThreadState *tstate; PyObject *func; PyObject *args; PyObject *kwargs; + _PyEventRc *thread_is_exiting; }; @@ -1226,6 +1296,9 @@ thread_bootstate_free(struct bootstate *boot, int decref) Py_DECREF(boot->args); Py_XDECREF(boot->kwargs); } + if (boot->thread_is_exiting != NULL) { + _PyEventRc_Decref(boot->thread_is_exiting); + } PyMem_RawFree(boot); } @@ -1236,6 +1309,10 @@ thread_run(void *boot_raw) struct bootstate *boot = (struct bootstate *) boot_raw; PyThreadState *tstate = boot->tstate; + // `thread_is_exiting` needs to be set after bootstate has been freed + _PyEventRc *thread_is_exiting = boot->thread_is_exiting; + boot->thread_is_exiting = NULL; + // gh-108987: If _thread.start_new_thread() is called before or while // Python is being finalized, thread_run() can called *after*. // _PyRuntimeState_SetFinalizing() is called. At this point, all Python @@ -1280,6 +1357,11 @@ thread_run(void *boot_raw) _PyThreadState_DeleteCurrent(tstate); exit: + if (thread_is_exiting != NULL) { + _PyEvent_Notify(&thread_is_exiting->event); + _PyEventRc_Decref(thread_is_exiting); + } + // bpo-44434: Don't call explicitly PyThread_exit_thread(). On Linux with // the glibc, pthread_exit() can abort the whole process if dlopen() fails // to open the libgcc_s.so library (ex: EMFILE error). 
@@ -1308,7 +1390,8 @@ static int do_start_new_thread(thread_module_state* state, PyObject *func, PyObject* args, PyObject* kwargs, int joinable, - PyThread_ident_t* ident, PyThread_handle_t* handle) + PyThread_ident_t* ident, PyThread_handle_t* handle, + _PyEventRc *thread_is_exiting) { PyInterpreterState *interp = _PyInterpreterState_GET(); if (!_PyInterpreterState_HasFeature(interp, Py_RTFLAGS_THREADS)) { @@ -1341,6 +1424,10 @@ do_start_new_thread(thread_module_state* state, boot->func = Py_NewRef(func); boot->args = Py_NewRef(args); boot->kwargs = Py_XNewRef(kwargs); + boot->thread_is_exiting = thread_is_exiting; + if (thread_is_exiting != NULL) { + _PyEventRc_Incref(thread_is_exiting); + } int err; if (joinable) { @@ -1392,7 +1479,7 @@ thread_PyThread_start_new_thread(PyObject *module, PyObject *fargs) PyThread_ident_t ident = 0; PyThread_handle_t handle; if (do_start_new_thread(state, func, args, kwargs, /*joinable=*/ 0, - &ident, &handle)) { + &ident, &handle, NULL)) { return NULL; } return PyLong_FromUnsignedLongLong(ident); @@ -1436,13 +1523,13 @@ thread_PyThread_start_joinable_thread(PyObject *module, PyObject *func) return NULL; } if (do_start_new_thread(state, func, args, /*kwargs=*/ NULL, /*joinable=*/ 1, - &hobj->ident, &hobj->handle)) { + &hobj->ident, &hobj->handle, hobj->thread_is_exiting)) { Py_DECREF(args); Py_DECREF(hobj); return NULL; } + set_thread_handle_state(hobj, THREAD_HANDLE_RUNNING); Py_DECREF(args); - hobj->joinable = 1; return (PyObject*) hobj; } diff --git a/Python/lock.c b/Python/lock.c index 5fa8bf78da23808..de25adce3851050 100644 --- a/Python/lock.c +++ b/Python/lock.c @@ -249,6 +249,13 @@ _PyRawMutex_UnlockSlow(_PyRawMutex *m) } } +int +_PyEvent_IsSet(PyEvent *evt) +{ + uint8_t v = _Py_atomic_load_uint8(&evt->v); + return v == _Py_LOCKED; +} + void _PyEvent_Notify(PyEvent *evt) { @@ -297,6 +304,30 @@ PyEvent_WaitTimed(PyEvent *evt, PyTime_t timeout_ns) } } +_PyEventRc * +_PyEventRc_New(void) +{ + _PyEventRc *erc = (_PyEventRc 
*)PyMem_RawCalloc(1, sizeof(_PyEventRc)); + if (erc != NULL) { + erc->refcount = 1; + } + return erc; +} + +void +_PyEventRc_Incref(_PyEventRc *erc) +{ + _Py_atomic_add_ssize(&erc->refcount, 1); +} + +void +_PyEventRc_Decref(_PyEventRc *erc) +{ + if (_Py_atomic_add_ssize(&erc->refcount, -1) == 1) { + PyMem_RawFree(erc); + } +} + static int unlock_once(_PyOnceFlag *o, int res) {