Skip to content

Commit

Permalink
Fix race condition in the Windows thread / pthread translation layer
Browse files Browse the repository at this point in the history
When spawning a Windows thread we have small worker wrapper function that translates
between the interfaces of Windows and POSIX threads.
This wrapper is given a pointer that might get stale before the worker starts running,
resulting in UB and crashes.
This commit adds synchronization so that we know the wrapper has finished reading the data
it needs before we allow the main thread to resume execution.
  • Loading branch information
yoniko committed Dec 17, 2022
1 parent dde58cd commit f5afaf1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
61 changes: 49 additions & 12 deletions lib/common/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,35 +34,72 @@ int g_ZSTD_threading_useless_symbol;

/* === Implementation === */

typedef struct {
void* (*start_routine)(void*);
void* arg;
int initialized;
ZSTD_pthread_cond_t initialized_cond;
ZSTD_pthread_mutex_t initialized_mutex;
} ZSTD_thread_params_t;

static unsigned __stdcall worker(void *arg)
{
ZSTD_pthread_t* const thread = (ZSTD_pthread_t*) arg;
thread->start_routine(thread->arg);
ZSTD_thread_params_t* const thread_param = (ZSTD_thread_params_t*)arg;
void* (*start_routine)(void*) = thread_param->start_routine;
void* thread_arg = thread_param->arg;

/* Signal main thread that we are running and do not depend on its memory anymore */
ZSTD_pthread_mutex_lock(&thread_param->initialized_mutex);
thread_param->initialized = 1;
ZSTD_pthread_mutex_unlock(&thread_param->initialized_mutex);
ZSTD_pthread_cond_signal(&thread_param->initialized_cond);

start_routine(thread_arg);

return 0;
}

int ZSTD_pthread_create(ZSTD_pthread_t* thread, const void* unused,
void* (*start_routine) (void*), void* arg)
{
ZSTD_thread_params_t thread_param;
int error = 0;
(void)unused;
thread->arg = arg;
thread->start_routine = start_routine;
thread->handle = (HANDLE) _beginthreadex(NULL, 0, worker, thread, 0, NULL);

if (!thread->handle)
thread_param.start_routine = start_routine;
thread_param.arg = arg;
thread_param.initialized = 0;

/* Setup thread initialization synchronization */
error |= ZSTD_pthread_cond_init(&thread_param.initialized_cond, NULL);
error |= ZSTD_pthread_mutex_init(&thread_param.initialized_mutex, NULL);
if(error)
return -1;
ZSTD_pthread_mutex_lock(&thread_param.initialized_mutex);

/* Spawn thread */
*thread = (HANDLE)_beginthreadex(NULL, 0, worker, &thread_param, 0, NULL);
if (!thread)
return errno;
else
return 0;

/* Wait for thread to be initialized */
while(!thread_param.initialized) {
ZSTD_pthread_cond_wait(&thread_param.initialized_cond, &thread_param.initialized_mutex);
}
ZSTD_pthread_mutex_unlock(&thread_param.initialized_mutex);
ZSTD_pthread_mutex_destroy(&thread_param.initialized_mutex);
ZSTD_pthread_cond_destroy(&thread_param.initialized_cond);

return 0;
}

int ZSTD_pthread_join(ZSTD_pthread_t thread)
{
DWORD result;

if (!thread.handle) return 0;
if (!thread) return 0;

result = WaitForSingleObject(thread.handle, INFINITE);
CloseHandle(thread.handle);
result = WaitForSingleObject(thread, INFINITE);
CloseHandle(thread);

switch (result) {
case WAIT_OBJECT_0:
Expand Down
6 changes: 1 addition & 5 deletions lib/common/threading.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ extern "C" {
#define ZSTD_pthread_cond_broadcast(a) WakeAllConditionVariable((a))

/* ZSTD_pthread_create() and ZSTD_pthread_join() */
typedef struct {
HANDLE handle;
void* (*start_routine)(void*);
void* arg;
} ZSTD_pthread_t;
typedef HANDLE ZSTD_pthread_t;

int ZSTD_pthread_create(ZSTD_pthread_t* thread, const void* unused,
void* (*start_routine) (void*), void* arg);
Expand Down

0 comments on commit f5afaf1

Please sign in to comment.