Skip to content

Commit

Permalink
feat: expose coroutine constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Dec 1, 2023
1 parent 8de760b commit 9ea2eb4
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 61 deletions.
22 changes: 21 additions & 1 deletion guide/src/async-await.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,24 @@ To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_

Each `coroutine.send` call is translated to `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;

*The type does not yet have a public constructor until the design is finalized.*
Coroutine can also be instantiated directly

```rust
# # ![allow(dead_code)]
use pyo3::prelude::*;
use pyo3::coroutine::{CancelHandle, Coroutine};

#[pyfunction]
fn new_coroutine(py: Python<'_>) -> Coroutine {
let mut cancel = CancelHandle::new();
let throw_callback = cancel.throw_callback();
let future = async move {
cancel.cancelled().await;
PyResult::Ok(())
};
Coroutine::new(pyo3::intern!(py, "my_coro"), future)
.with_qualname_prefix("MyClass")
.with_throw_callback(throw_callback)
.with_allow_threads(true)
}
```
1 change: 1 addition & 0 deletions newsfragments/3613.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Expose `Coroutine` constructor
8 changes: 4 additions & 4 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,13 +504,13 @@ impl<'a> FnSpec<'a> {
};
let mut call = quote! {{
let future = #future;
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::coroutine::Coroutine::new(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
#throw_callback,
#allow_threads,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
)
.with_qualname_prefix(#qualname_prefix)
.with_throw_callback(#throw_callback)
.with_allow_threads(#allow_threads)
}};
if cancel_handle.is_some() {
call = quote! {{
Expand Down
64 changes: 37 additions & 27 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use pyo3_macros::{pyclass, pymethods};

use crate::{
coroutine::waker::CoroutineWaker,
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
exceptions::{PyRuntimeError, PyStopIteration},
pyclass::IterNextOutput,
types::PyString,
IntoPy, Py, PyErr, PyObject, PyResult, Python,
Expand All @@ -26,20 +26,19 @@ pub(crate) mod cancel;
mod trio;
pub(crate) mod waker;

use crate::coroutine::cancel::ThrowCallback;
use crate::panic::PanicException;
pub use cancel::CancelHandle;
pub use cancel::{CancelHandle, ThrowCallback};

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
name: Py<PyString>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
waker: Option<Arc<CoroutineWaker>>,
}

Expand All @@ -50,13 +49,7 @@ impl Coroutine {
/// (should always be `None` anyway).
///
/// `Coroutine `throw` drop the wrapped future and reraise the exception passed
pub(crate) fn new<F, T, E>(
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Self
pub fn new<F, T, E>(name: impl Into<Py<PyString>>, future: F) -> Self
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
Expand All @@ -68,15 +61,36 @@ impl Coroutine {
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
};
Self {
name,
qualname_prefix,
throw_callback,
allow_threads,
future: Some(Box::pin(wrap)),
name: name.into(),
qualname_prefix: None,
throw_callback: None,
allow_threads: false,
waker: None,
}
}

/// Set a prefix for `__qualname__`, which will be joined with a "."
pub fn with_qualname_prefix(mut self, prefix: impl Into<Option<&'static str>>) -> Self {
self.qualname_prefix = prefix.into();
self
}

/// Register a callback for coroutine `throw` method.
///
/// The exception passed to `throw` is then redirected to this callback, notifying the
/// associated [`CancelHandle`], without being reraised.
pub fn with_throw_callback(mut self, callback: impl Into<Option<ThrowCallback>>) -> Self {
self.throw_callback = callback.into();
self
}

/// Release the GIL while polling the future wrapped.
pub fn with_allow_threads(mut self, allow_threads: bool) -> Self {
self.allow_threads = allow_threads;
self
}

fn poll_inner(
&mut self,
py: Python<'_>,
Expand Down Expand Up @@ -151,22 +165,18 @@ pub(crate) fn iter_result(result: IterNextOutput<PyObject, PyObject>) -> PyResul
#[pymethods(crate = "crate")]
impl Coroutine {
#[getter]
fn __name__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
match &self.name {
Some(name) => Ok(name.clone_ref(py)),
None => Err(PyAttributeError::new_err("__name__")),
}
fn __name__(&self, py: Python<'_>) -> Py<PyString> {
self.name.clone_ref(py)
}

#[getter]
fn __qualname__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
match (&self.name, &self.qualname_prefix) {
(Some(name), Some(prefix)) => Ok(format!("{}.{}", prefix, name.as_ref(py).to_str()?)
Ok(match &self.qualname_prefix {
Some(prefix) => format!("{}.{}", prefix, self.name.as_ref(py).to_str()?)
.as_str()
.into_py(py)),
(Some(name), None) => Ok(name.clone_ref(py)),
(None, _) => Err(PyAttributeError::new_err("__qualname__")),
}
.into_py(py),
None => self.name.clone_ref(py),
})
}

fn send(&mut self, py: Python<'_>, value: PyObject) -> PyResult<PyObject> {
Expand Down
4 changes: 2 additions & 2 deletions src/coroutine/cancel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl CancelHandle {
Cancelled(self).await
}

#[doc(hidden)]
/// Instantiate a [`ThrowCallback`] associated to this cancel handle.
pub fn throw_callback(&self) -> ThrowCallback {
ThrowCallback(self.0.clone())
}
Expand All @@ -71,7 +71,7 @@ impl Future for Cancelled<'_> {
}
}

#[doc(hidden)]
/// Callback for coroutine `throw` method, notifying the associated [`CancelHandle`]
pub struct ThrowCallback(Arc<Inner>);

impl ThrowCallback {
Expand Down
27 changes: 1 addition & 26 deletions src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,8 @@
use std::future::Future;
use std::mem;

use crate::coroutine::cancel::ThrowCallback;
use crate::pyclass::boolean_struct::False;
use crate::{
coroutine::Coroutine, types::PyString, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject,
PyRef, PyRefMut, PyResult, Python,
};

pub fn new_coroutine<F, T, E>(
name: &PyString,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Coroutine
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
{
Coroutine::new(
Some(name.into()),
qualname_prefix,
throw_callback,
allow_threads,
future,
)
}
use crate::{Py, PyAny, PyCell, PyClass, PyRef, PyRefMut, PyResult, Python};

fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
// SAFETY: Py<T> can be casted as *const PyCell<T>
Expand Down
54 changes: 53 additions & 1 deletion tests/test_coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#![cfg(feature = "macros")]
#![cfg(not(target_arch = "wasm32"))]
use std::ops::Deref;
use std::sync::Arc;
use std::{task::Poll, thread, time::Duration};

use futures::{channel::oneshot, future::poll_fn, FutureExt};
use pyo3::coroutine::CancelHandle;
use pyo3::coroutine::{CancelHandle, Coroutine};
use pyo3::sync::GILOnceCell;
use pyo3::types::{IntoPyDict, PyType};
use pyo3::{prelude::*, py_run};

Expand Down Expand Up @@ -183,3 +185,53 @@ fn test_async_method_receiver() {
async fn method_mut(&mut self) {}
}
}

#[test]
fn multi_thread_event_loop() {
Python::with_gil(|gil| {
let sleep = wrap_pyfunction!(sleep, gil).unwrap();
let test = r#"
import asyncio
import threading
loop = asyncio.new_event_loop()
# spawn the sleep task and run just one iteration of the event loop
# to schedule the sleep wakeup
task = loop.create_task(sleep(0.1))
loop.stop()
loop.run_forever()
assert not task.done()
# spawn a thread to complete the execution of the sleep task
def target(loop, task):
loop.run_until_complete(task)
thread = threading.Thread(target=target, args=(loop, task))
thread.start()
thread.join()
assert task.result() == 42
"#;
py_run!(gil, sleep, test);
})
}

#[test]
fn closed_event_loop() {
let waker = Arc::new(GILOnceCell::new());
let waker2 = waker.clone();
let future = poll_fn(move |cx| {
Python::with_gil(|gil| waker2.set(gil, cx.waker().clone()).unwrap());
Poll::Pending::<PyResult<()>>
});
Python::with_gil(|gil| {
let register_waker = Coroutine::new("register_waker".into_py(gil), future).into_py(gil);
let test = r#"
import asyncio
loop = asyncio.new_event_loop()
# register a waker by spawning a task and polling it once, then close the loop
task = loop.create_task(register_waker)
loop.stop()
loop.run_forever()
loop.close()
"#;
py_run!(gil, register_waker, test);
Python::with_gil(|gil| waker.get(gil).unwrap().wake_by_ref())
})
}

0 comments on commit 9ea2eb4

Please sign in to comment.