diff --git a/guide/src/async-await.md b/guide/src/async-await.md index 3995524971d..077f32b585a 100644 --- a/guide/src/async-await.md +++ b/guide/src/async-await.md @@ -96,4 +96,24 @@ To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_ Each `coroutine.send` call is translated to `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised; -*The type does not yet have a public constructor until the design is finalized.* \ No newline at end of file +Coroutine can also be instantiated directly + +```rust +# # ![allow(dead_code)] +use pyo3::prelude::*; +use pyo3::coroutine::{CancelHandle, Coroutine}; + +#[pyfunction] +fn new_coroutine(py: Python<'_>) -> Coroutine { + let mut cancel = CancelHandle::new(); + let throw_callback = cancel.throw_callback(); + let future = async move { + cancel.cancelled().await; + PyResult::Ok(()) + }; + Coroutine::new(pyo3::intern!(py, "my_coro"), future) + .with_qualname_prefix("MyClass") + .with_throw_callback(throw_callback) + .with_allow_threads(true) +} +``` diff --git a/newsfragments/3613.added.md b/newsfragments/3613.added.md new file mode 100644 index 00000000000..6016f54ffa2 --- /dev/null +++ b/newsfragments/3613.added.md @@ -0,0 +1 @@ +Expose `Coroutine` constructor \ No newline at end of file diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 7ced62c6d2b..022067fbf61 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -504,13 +504,13 @@ impl<'a> FnSpec<'a> { }; let mut call = quote! {{ let future = #future; - _pyo3::impl_::coroutine::new_coroutine( + _pyo3::coroutine::Coroutine::new( _pyo3::intern!(py, stringify!(#python_name)), - #qualname_prefix, - #throw_callback, - #allow_threads, async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }, ) + .with_qualname_prefix(#qualname_prefix) + .with_throw_callback(#throw_callback) + .with_allow_threads(#allow_threads) }}; if cancel_handle.is_some() { call = quote! {{ diff --git a/src/coroutine.rs b/src/coroutine.rs index c2a7fba36d2..dbd6d65b896 100644 --- a/src/coroutine.rs +++ b/src/coroutine.rs @@ -12,7 +12,7 @@ use pyo3_macros::{pyclass, pymethods}; use crate::{ coroutine::waker::CoroutineWaker, - exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration}, + exceptions::{PyRuntimeError, PyStopIteration}, pyclass::IterNextOutput, types::PyString, IntoPy, Py, PyErr, PyObject, PyResult, Python, @@ -26,20 +26,19 @@ pub(crate) mod cancel; mod trio; pub(crate) mod waker; -use crate::coroutine::cancel::ThrowCallback; use crate::panic::PanicException; -pub use cancel::CancelHandle; +pub use cancel::{CancelHandle, ThrowCallback}; const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine"; /// Python coroutine wrapping a [`Future`]. #[pyclass(crate = "crate")] pub struct Coroutine { - name: Option>, + future: Option> + Send>>>, + name: Py, qualname_prefix: Option<&'static str>, throw_callback: Option, allow_threads: bool, - future: Option> + Send>>>, waker: Option>, } @@ -50,13 +49,7 @@ impl Coroutine { /// (should always be `None` anyway). /// /// `Coroutine `throw` drop the wrapped future and reraise the exception passed - pub(crate) fn new( - name: Option>, - qualname_prefix: Option<&'static str>, - throw_callback: Option, - allow_threads: bool, - future: F, - ) -> Self + pub fn new(name: impl Into>, future: F) -> Self where F: Future> + Send + 'static, T: IntoPy, @@ -68,15 +61,36 @@ impl Coroutine { Ok(obj.into_py(unsafe { Python::assume_gil_acquired() })) }; Self { - name, - qualname_prefix, - throw_callback, - allow_threads, future: Some(Box::pin(wrap)), + name: name.into(), + qualname_prefix: None, + throw_callback: None, + allow_threads: false, waker: None, } } + /// Set a prefix for `__qualname__`, which will be joined with a "." + pub fn with_qualname_prefix(mut self, prefix: impl Into>) -> Self { + self.qualname_prefix = prefix.into(); + self + } + + /// Register a callback for coroutine `throw` method. + /// + /// The exception passed to `throw` is then redirected to this callback, notifying the + /// associated [`CancelHandle`], without being reraised. + pub fn with_throw_callback(mut self, callback: impl Into>) -> Self { + self.throw_callback = callback.into(); + self + } + + /// Release the GIL while polling the future wrapped. + pub fn with_allow_threads(mut self, allow_threads: bool) -> Self { + self.allow_threads = allow_threads; + self + } + fn poll_inner( &mut self, py: Python<'_>, @@ -151,22 +165,18 @@ pub(crate) fn iter_result(result: IterNextOutput) -> PyResul #[pymethods(crate = "crate")] impl Coroutine { #[getter] - fn __name__(&self, py: Python<'_>) -> PyResult> { - match &self.name { - Some(name) => Ok(name.clone_ref(py)), - None => Err(PyAttributeError::new_err("__name__")), - } + fn __name__(&self, py: Python<'_>) -> Py { + self.name.clone_ref(py) } #[getter] fn __qualname__(&self, py: Python<'_>) -> PyResult> { - match (&self.name, &self.qualname_prefix) { - (Some(name), Some(prefix)) => Ok(format!("{}.{}", prefix, name.as_ref(py).to_str()?) + Ok(match &self.qualname_prefix { + Some(prefix) => format!("{}.{}", prefix, self.name.as_ref(py).to_str()?) .as_str() - .into_py(py)), - (Some(name), None) => Ok(name.clone_ref(py)), - (None, _) => Err(PyAttributeError::new_err("__qualname__")), - } + .into_py(py), + None => self.name.clone_ref(py), + }) } fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult { diff --git a/src/coroutine/cancel.rs b/src/coroutine/cancel.rs index d02d59a04bf..4be8226263b 100644 --- a/src/coroutine/cancel.rs +++ b/src/coroutine/cancel.rs @@ -56,7 +56,7 @@ impl CancelHandle { Cancelled(self).await } - #[doc(hidden)] + /// Instantiate a [`ThrowCallback`] associated to this cancel handle. pub fn throw_callback(&self) -> ThrowCallback { ThrowCallback(self.0.clone()) } @@ -71,7 +71,7 @@ impl Future for Cancelled<'_> { } } -#[doc(hidden)] +/// Callback for coroutine `throw` method, notifying the associated [`CancelHandle`] pub struct ThrowCallback(Arc); impl ThrowCallback { diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs index f689ee7cb2f..31889de69f5 100644 --- a/src/impl_/coroutine.rs +++ b/src/impl_/coroutine.rs @@ -1,33 +1,8 @@ use std::future::Future; use std::mem; -use crate::coroutine::cancel::ThrowCallback; use crate::pyclass::boolean_struct::False; -use crate::{ - coroutine::Coroutine, types::PyString, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, - PyRef, PyRefMut, PyResult, Python, -}; - -pub fn new_coroutine( - name: &PyString, - qualname_prefix: Option<&'static str>, - throw_callback: Option, - allow_threads: bool, - future: F, -) -> Coroutine -where - F: Future> + Send + 'static, - T: IntoPy, - E: Into, -{ - Coroutine::new( - Some(name.into()), - qualname_prefix, - throw_callback, - allow_threads, - future, - ) -} +use crate::{Py, PyAny, PyCell, PyClass, PyRef, PyRefMut, PyResult, Python}; fn get_ptr(obj: &Py) -> *mut T { // SAFETY: Py can be casted as *const PyCell diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 41650f9a822..9954c543899 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -1,10 +1,12 @@ #![cfg(feature = "macros")] #![cfg(not(target_arch = "wasm32"))] use std::ops::Deref; +use std::sync::Arc; use std::{task::Poll, thread, time::Duration}; use futures::{channel::oneshot, future::poll_fn, FutureExt}; -use pyo3::coroutine::CancelHandle; +use pyo3::coroutine::{CancelHandle, Coroutine}; +use pyo3::sync::GILOnceCell; use pyo3::types::{IntoPyDict, PyType}; use pyo3::{prelude::*, py_run}; @@ -183,3 +185,53 @@ fn test_async_method_receiver() { async fn method_mut(&mut self) {} } } + +#[test] +fn multi_thread_event_loop() { + Python::with_gil(|gil| { + let sleep = wrap_pyfunction!(sleep, gil).unwrap(); + let test = r#" + import asyncio + import threading + loop = asyncio.new_event_loop() + # spawn the sleep task and run just one iteration of the event loop + # to schedule the sleep wakeup + task = loop.create_task(sleep(0.1)) + loop.stop() + loop.run_forever() + assert not task.done() + # spawn a thread to complete the execution of the sleep task + def target(loop, task): + loop.run_until_complete(task) + thread = threading.Thread(target=target, args=(loop, task)) + thread.start() + thread.join() + assert task.result() == 42 + "#; + py_run!(gil, sleep, test); + }) +} + +#[test] +fn closed_event_loop() { + let waker = Arc::new(GILOnceCell::new()); + let waker2 = waker.clone(); + let future = poll_fn(move |cx| { + Python::with_gil(|gil| waker2.set(gil, cx.waker().clone()).unwrap()); + Poll::Pending::> + }); + Python::with_gil(|gil| { + let register_waker = Coroutine::new("register_waker".into_py(gil), future).into_py(gil); + let test = r#" + import asyncio + loop = asyncio.new_event_loop() + # register a waker by spawning a task and polling it once, then close the loop + task = loop.create_task(register_waker) + loop.stop() + loop.run_forever() + loop.close() + "#; + py_run!(gil, register_waker, test); + Python::with_gil(|gil| waker.get(gil).unwrap().wake_by_ref()) + }) +}