Skip to content

Commit

Permalink
[wgpu-core] Ensure that DeviceLostCallback is always called exactly once
Browse files Browse the repository at this point in the history
* Ensure device lost closure is called exactly once before being dropped.

This requires a change to the Rust callback signature, which is now Fn
instead of FnOnce. When the Rust callback or the C closure are dropped,
they will panic if they haven't been called. `device_drop` is changed
to call the closure with a message of "Device dropped." A test is added.
  • Loading branch information
bradwerth authored Dec 19, 2023
1 parent aade481 commit 56d9d32
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 13 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Wgpu now exposes backend feature for the Direct3D 12 (`dx12`) and Metal (`metal`
- No longer validate surfaces against their allowed extent range on configure. This caused warnings that were almost impossible to avoid. As before, the resulting behavior depends on the compositor. By @wumpf in [#4796](https://github.com/gfx-rs/wgpu/pull/4796)
- Added support for the float32-filterable feature. By @almarklein in [#4759](https://github.com/gfx-rs/wgpu/pull/4759)
- wgpu and wgpu-core features are now documented on docs.rs. By @wumpf in [#4886](https://github.com/gfx-rs/wgpu/pull/4886)
- DeviceLostClosure is guaranteed to be invoked exactly once. By @bradwerth in [#4862](https://github.com/gfx-rs/wgpu/pull/4862)

#### OpenGL
- `@builtin(instance_index)` now properly reflects the range provided in the draw call instead of always counting from 0. By @cwfitzgerald in [#4722](https://github.com/gfx-rs/wgpu/pull/4722).
Expand Down Expand Up @@ -757,7 +758,7 @@ By @cwfitzgerald in [#3671](https://github.com/gfx-rs/wgpu/pull/3671).

- Implemented basic ray-tracing api for acceleration structures, and ray-queries @daniel-keitel (started by @expenses) in [#3507](https://github.com/gfx-rs/wgpu/pull/3507)

#### Hal
#### Hal

- Added basic ray-tracing api for acceleration structures, and ray-queries @daniel-keitel (started by @expenses) in [#3507](https://github.com/gfx-rs/wgpu/pull/3507)

Expand Down
33 changes: 33 additions & 0 deletions tests/tests/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,36 @@ static DEVICE_DESTROY_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::ne
"Device lost callback should have been called."
);
});

#[gpu_test]
static DEVICE_DROP_THEN_LOST: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().expect_fail(FailureCase::webgl2()))
.run_sync(|ctx| {
// This test checks that when the device is dropped (such as in a GC),
// the provided DeviceLostClosure is called with reason DeviceLostReason::Unknown.
// Fails on webgl because webgl doesn't implement drop.
let was_called = std::sync::Arc::<std::sync::atomic::AtomicBool>::new(false.into());

// Set a LoseDeviceCallback on the device.
let was_called_clone = was_called.clone();
let callback = Box::new(move |reason, message| {
was_called_clone.store(true, std::sync::atomic::Ordering::SeqCst);
assert!(
matches!(reason, wgt::DeviceLostReason::Unknown),
"Device lost info reason should match DeviceLostReason::Unknown."
);
assert!(
message == "Device dropped.",
"Device lost info message should be \"Device dropped.\"."
);
});
ctx.device.set_device_lost_callback(callback);

// Drop the device.
drop(ctx.device);

assert!(
was_called.load(std::sync::atomic::Ordering::SeqCst),
"Device lost callback should have been called."
);
});
11 changes: 9 additions & 2 deletions wgpu-core/src/device/global.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use crate::device::trace;
use crate::{
api_log, binding_model, command, conv,
device::{
life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure, HostMap,
IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL,
life::WaitIdleError, map_buffer, queue, DeviceError, DeviceLostClosure, DeviceLostReason,
HostMap, IMPLICIT_BIND_GROUP_LAYOUT_ERROR_LABEL,
},
global::Global,
hal_api::HalApi,
Expand Down Expand Up @@ -2239,6 +2239,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {

let hub = A::hub(self);
if let Some(device) = hub.devices.unregister(device_id) {
let device_lost_closure = device.lock_life().device_lost_closure.take();
if let Some(closure) = device_lost_closure {
closure.call(DeviceLostReason::Unknown, String::from("Device dropped."));
}

// The things `Device::prepare_to_die` takes care are mostly
// unnecessary here. We know our queue is empty, so we don't
// need to wait for submissions or triage them. We know we were
Expand All @@ -2254,6 +2259,8 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
}

// This closure will be called exactly once during "lose the device"
// or when the device is dropped, if it was never lost.
pub fn device_set_device_lost_closure<A: HalApi>(
&self,
device_id: DeviceId,
Expand Down
59 changes: 52 additions & 7 deletions wgpu-core/src/device/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,20 +211,34 @@ impl UserClosures {
not(target_feature = "atomics")
)
))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(any(
not(target_arch = "wasm32"),
all(
feature = "fragile-send-sync-non-atomic-wasm",
not(target_feature = "atomics")
)
)))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;

pub struct DeviceLostClosureRust {
pub callback: DeviceLostCallback,
called: bool,
}

impl Drop for DeviceLostClosureRust {
fn drop(&mut self) {
if !self.called {
panic!("DeviceLostClosureRust must be called before it is dropped.");
}
}
}

#[repr(C)]
pub struct DeviceLostClosureC {
pub callback: unsafe extern "C" fn(user_data: *mut u8, reason: u8, message: *const c_char),
pub user_data: *mut u8,
called: bool,
}

#[cfg(any(
Expand All @@ -236,6 +250,14 @@ pub struct DeviceLostClosureC {
))]
unsafe impl Send for DeviceLostClosureC {}

impl Drop for DeviceLostClosureC {
fn drop(&mut self) {
if !self.called {
panic!("DeviceLostClosureC must be called before it is dropped.");
}
}
}

pub struct DeviceLostClosure {
// We wrap this so creating the enum in the C variant can be unsafe,
// allowing our call function to be safe.
Expand All @@ -249,14 +271,18 @@ pub struct DeviceLostInvocation {
}

enum DeviceLostClosureInner {
Rust { callback: DeviceLostCallback },
Rust { inner: DeviceLostClosureRust },
C { inner: DeviceLostClosureC },
}

impl DeviceLostClosure {
pub fn from_rust(callback: DeviceLostCallback) -> Self {
let inner = DeviceLostClosureRust {
callback,
called: false,
};
Self {
inner: DeviceLostClosureInner::Rust { callback },
inner: DeviceLostClosureInner::Rust { inner },
}
}

Expand All @@ -267,17 +293,36 @@ impl DeviceLostClosure {
///
/// - Both pointers must point to `'static` data, as the callback may happen at
/// an unspecified time.
pub unsafe fn from_c(inner: DeviceLostClosureC) -> Self {
pub unsafe fn from_c(closure: DeviceLostClosureC) -> Self {
// Build an inner with the values from closure, ensuring that
// inner.called is false.
let inner = DeviceLostClosureC {
callback: closure.callback,
user_data: closure.user_data,
called: false,
};
Self {
inner: DeviceLostClosureInner::C { inner },
}
}

pub(crate) fn call(self, reason: DeviceLostReason, message: String) {
match self.inner {
DeviceLostClosureInner::Rust { callback } => callback(reason, message),
DeviceLostClosureInner::Rust { mut inner } => {
if inner.called {
panic!("DeviceLostClosureRust must only be called once.");
}
inner.called = true;

(inner.callback)(reason, message)
}
// SAFETY: the contract of the call to from_c says that this unsafe is sound.
DeviceLostClosureInner::C { inner } => unsafe {
DeviceLostClosureInner::C { mut inner } => unsafe {
if inner.called {
panic!("DeviceLostClosureC must only be called once.");
}
inner.called = true;

// Ensure message is structured as a null-terminated C string. It only
// needs to live as long as the callback invocation.
let message = std::ffi::CString::new(message).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions wgpu/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,15 +1210,15 @@ pub type SubmittedWorkDoneCallback = Box<dyn FnOnce() + 'static>;
not(target_feature = "atomics")
)
))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + Send + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + Send + 'static>;
#[cfg(not(any(
not(target_arch = "wasm32"),
all(
feature = "fragile-send-sync-non-atomic-wasm",
not(target_feature = "atomics")
)
)))]
pub type DeviceLostCallback = Box<dyn FnOnce(DeviceLostReason, String) + 'static>;
pub type DeviceLostCallback = Box<dyn Fn(DeviceLostReason, String) + 'static>;

/// An object safe variant of [`Context`] implemented by all types that implement [`Context`].
pub(crate) trait DynContext: Debug + WasmNotSendSync {
Expand Down
2 changes: 1 addition & 1 deletion wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2923,7 +2923,7 @@ impl Device {
/// Set a DeviceLostCallback on this device.
pub fn set_device_lost_callback(
&self,
callback: impl FnOnce(DeviceLostReason, String) + Send + 'static,
callback: impl Fn(DeviceLostReason, String) + Send + 'static,
) {
DynContext::device_set_device_lost_callback(
&*self.context,
Expand Down

0 comments on commit 56d9d32

Please sign in to comment.