Add thread create/destroy callbacks to TaskPool (bevyengine#6561)
# Objective
Fix bevyengine#1991. Allow users to have a bit more control over the creation and finalization of the threads in `TaskPool`.

## Solution
Add new methods to `TaskPoolBuilder` that expose callbacks that are called to initialize and finalize each thread in the `TaskPool`.

Unlike the proposed solution in bevyengine#1991, the callback takes no arguments. If an identifier is needed, `std::thread::current` should provide that information easily.

Added a unit test to ensure that they're being called correctly.
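
For illustration, here is a minimal sketch of how the new builder methods might be used together (the thread count, pool name, and `println!` bodies are made up for this example, not part of the commit):

```rust
use bevy_tasks::TaskPoolBuilder;

fn main() {
    let pool = TaskPoolBuilder::new()
        .num_threads(4)
        .thread_name("MyThreadPool".to_string())
        .on_thread_spawn(|| {
            // Runs on each newly spawned pool thread before it starts polling tasks.
            println!("spawned: {:?}", std::thread::current().name());
        })
        .on_thread_destroy(|| {
            // Runs on each pool thread right before it terminates.
            println!("destroying: {:?}", std::thread::current().name());
        })
        .build();

    // Dropping the pool shuts its threads down, which triggers `on_thread_destroy`.
    drop(pool);
}
```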
james7132 authored and alradish committed Jan 22, 2023
1 parent 39cd9d1 commit 8ce5a4f
Showing 1 changed file with 98 additions and 14 deletions.
112 changes: 98 additions & 14 deletions crates/bevy_tasks/src/task_pool.rs
@@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt};

use crate::Task;

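// Calls the wrapped callback (if any) when dropped; used below to run `on_thread_destroy` as a pool thread exits.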
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);

impl Drop for CallOnDrop {
    fn drop(&mut self) {
        if let Some(call) = self.0.as_ref() {
            call();
        }
    }
}

/// Used to create a [`TaskPool`]
-#[derive(Debug, Default, Clone)]
+#[derive(Default)]
#[must_use]
pub struct TaskPoolBuilder {
    /// If set, we'll set up the thread pool to use at most `num_threads` threads.
@@ -24,6 +34,9 @@ pub struct TaskPoolBuilder {
    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
    /// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
    thread_name: Option<String>,

    on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
    on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}

impl TaskPoolBuilder {
@@ -52,13 +65,27 @@ impl TaskPoolBuilder {
        self
    }

    /// Sets a callback that is invoked once for every created thread as it starts.
    ///
    /// This is called on the thread itself and has access to all thread-local storage.
    /// This will block running async tasks on the thread until the callback completes.
    pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
        self.on_thread_spawn = Some(Arc::new(f));
        self
    }

    /// Sets a callback that is invoked once for every created thread as it terminates.
    ///
    /// This is called on the thread itself and has access to all thread-local storage.
    /// This will block thread termination until the callback completes.
    pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
        self.on_thread_destroy = Some(Arc::new(f));
        self
    }

    /// Creates a new [`TaskPool`] based on the current options.
    pub fn build(self) -> TaskPool {
-        TaskPool::new_internal(
-            self.num_threads,
-            self.stack_size,
-            self.thread_name.as_deref(),
-        )
+        TaskPool::new_internal(self)
    }
}

@@ -88,36 +115,42 @@ impl TaskPool {
        TaskPoolBuilder::new().build()
    }

-    fn new_internal(
-        num_threads: Option<usize>,
-        stack_size: Option<usize>,
-        thread_name: Option<&str>,
-    ) -> Self {
+    fn new_internal(builder: TaskPoolBuilder) -> Self {
        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();

        let executor = Arc::new(async_executor::Executor::new());

-        let num_threads = num_threads.unwrap_or_else(crate::available_parallelism);
+        let num_threads = builder
+            .num_threads
+            .unwrap_or_else(crate::available_parallelism);

        let threads = (0..num_threads)
            .map(|i| {
                let ex = Arc::clone(&executor);
                let shutdown_rx = shutdown_rx.clone();

-                let thread_name = if let Some(thread_name) = thread_name {
+                let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
                    format!("{thread_name} ({i})")
                } else {
                    format!("TaskPool ({i})")
                };
                let mut thread_builder = thread::Builder::new().name(thread_name);

-                if let Some(stack_size) = stack_size {
+                if let Some(stack_size) = builder.stack_size {
                    thread_builder = thread_builder.stack_size(stack_size);
                }

                let on_thread_spawn = builder.on_thread_spawn.clone();
                let on_thread_destroy = builder.on_thread_destroy.clone();

                thread_builder
                    .spawn(move || {
                        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
                            if let Some(on_thread_spawn) = on_thread_spawn {
                                on_thread_spawn();
                                drop(on_thread_spawn);
                            }
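                            // Dropping this guard when the thread's main closure exits runs `on_thread_destroy`.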
                            let _destructor = CallOnDrop(on_thread_destroy);
                            loop {
                                let res = std::panic::catch_unwind(|| {
                                    let tick_forever = async move {
@@ -452,6 +485,57 @@ mod tests {
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    #[test]
    fn test_thread_callbacks() {
        let counter = Arc::new(AtomicI32::new(0));
        let start_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(11));
            let last_barrier = barrier.clone();
            // Build and immediately drop to terminate
            let _pool = TaskPoolBuilder::new()
                .num_threads(10)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.clone().wait();
                })
                .build();
            last_barrier.wait();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(10, counter.load(Ordering::Relaxed));
        let end_counter = counter.clone();
        {
            let _pool = TaskPoolBuilder::new()
                .num_threads(20)
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            assert_eq!(10, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
        let start_counter = counter.clone();
        let end_counter = counter.clone();
        {
            let barrier = Arc::new(Barrier::new(6));
            let last_barrier = barrier.clone();
            let _pool = TaskPoolBuilder::new()
                .num_threads(5)
                .on_thread_spawn(move || {
                    start_counter.fetch_add(1, Ordering::Relaxed);
                    barrier.wait();
                })
                .on_thread_destroy(move || {
                    end_counter.fetch_sub(1, Ordering::Relaxed);
                })
                .build();
            last_barrier.wait();
            assert_eq!(-5, counter.load(Ordering::Relaxed));
        }
        assert_eq!(-10, counter.load(Ordering::Relaxed));
    }

    #[test]
    fn test_mixed_spawn_on_scope_and_spawn() {
        let pool = TaskPool::new();
