diff --git a/newsfragments/2884.added.md b/newsfragments/2884.added.md new file mode 100644 index 00000000000..0d4c9265b1d --- /dev/null +++ b/newsfragments/2884.added.md @@ -0,0 +1 @@ +Added `PyErr::write_unraisable()` to report an unraisable exception to Python. diff --git a/src/err/mod.rs b/src/err/mod.rs index 0f397153286..eecbd0cc471 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -476,6 +476,40 @@ impl PyErr { unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) } } + /// Reports the error as unraisable. + /// + /// This calls `sys.unraisablehook()` using the current exception and obj argument. + /// + /// This method is useful to report errors in situations where there is no good mechanism + /// to report back to the Python land. In Python this is used to indicate errors in + /// background threads or destructors which are protected. In Rust code this is commonly + /// useful when you are calling into a Python callback which might fail, but there is no + /// obvious way to handle this error other than logging it. + /// + /// Calling this method has the benefit that the error goes back into a standardized callback + /// in Python which for instance allows unittests to ensure that no unraisable error + /// actually happend by hooking `sys.unraisablehook`. + /// + /// Example: + /// ```rust + /// # use pyo3::prelude::*; + /// # use pyo3::exceptions::PyRuntimeError; + /// fn failing_function() -> PyResult<()> { Err(PyRuntimeError::new_err("foo")) } + /// # fn main() -> PyResult<()> { + /// Python::with_gil(|py| { + /// match failing_function() { + /// Err(pyerr) => pyerr.write_unraisable(py, None), + /// Ok(..) => { /* do something here */ } + /// } + /// Ok(()) + /// }) + /// # } + #[inline] + pub fn write_unraisable(self, py: Python<'_>, obj: Option<&PyAny>) { + self.restore(py); + unsafe { ffi::PyErr_WriteUnraisable(obj.map_or(std::ptr::null_mut(), |x| x.as_ptr())) } + } + /// Issues a warning message. /// /// May return an `Err(PyErr)` if warnings-as-errors is enabled. diff --git a/tests/test_exceptions.rs b/tests/test_exceptions.rs index 98dab27bc70..c281beaad76 100644 --- a/tests/test_exceptions.rs +++ b/tests/test_exceptions.rs @@ -1,6 +1,7 @@ #![cfg(feature = "macros")] use pyo3::prelude::*; +use pyo3::types::PyDict; use pyo3::{exceptions, py_run, PyErr, PyResult}; use std::error::Error; use std::fmt; @@ -96,3 +97,51 @@ fn test_exception_nosegfault() { assert!(io_err().is_err()); assert!(parse_int().is_err()); } + +#[pyfunction] +fn report_unraisable(py: Python) { + use pyo3::exceptions::PyRuntimeError; + let err = PyRuntimeError::new_err("foo"); + err.write_unraisable(py, None); + + let err = PyRuntimeError::new_err("bar"); + err.write_unraisable(py, Some(py.NotImplemented().as_ref(py))); +} + +#[test] +fn test_write_unraisable() { + Python::with_gil(|py| { + let report_unraisable = wrap_pyfunction!(report_unraisable, py).unwrap(); + let locals = PyDict::new(py); + locals + .set_item("report_unraisable", report_unraisable) + .unwrap(); + + let source = r#"if True: + import sys + + captured = [] + def report(data): + captured.append(list(data)) + + original_hook = sys.unraisablehook + try: + sys.unraisablehook = report + report_unraisable() + + assert len(captured) == 2 + + assert captured[0][0] is RuntimeError + assert str(captured[0][1]) == 'foo' + assert captured[0][4] is None + + assert captured[1][0] is RuntimeError + assert str(captured[1][1]) == 'bar' + assert captured[1][4] is NotImplemented + finally: + sys.unraisablehook = original_hook + "#; + + py.run(source, Some(locals), None).unwrap(); + }); +}