diff --git a/pyo3-derive-backend/src/defs.rs b/pyo3-derive-backend/src/defs.rs index e1990a158f0..c607156858b 100644 --- a/pyo3-derive-backend/src/defs.rs +++ b/pyo3-derive-backend/src/defs.rs @@ -156,18 +156,21 @@ pub const ASYNC: Proto = Proto { slot_table: "pyo3::ffi::PyAsyncMethods", set_slot_table: "set_async_methods", methods: &[ - MethodProto::Unary { + MethodProto::UnaryS { name: "__await__", + arg: "Receiver", pyres: true, proto: "pyo3::class::pyasync::PyAsyncAwaitProtocol", }, - MethodProto::Unary { + MethodProto::UnaryS { name: "__aiter__", + arg: "Receiver", pyres: true, proto: "pyo3::class::pyasync::PyAsyncAiterProtocol", }, - MethodProto::Unary { + MethodProto::UnaryS { name: "__anext__", + arg: "Receiver", pyres: true, proto: "pyo3::class::pyasync::PyAsyncAnextProtocol", }, diff --git a/src/class/macros.rs b/src/class/macros.rs index a4500d1fa32..a009b5ebcf1 100644 --- a/src/class/macros.rs +++ b/src/class/macros.rs @@ -34,8 +34,9 @@ macro_rules! py_unarys_func { { $crate::callback_body!(py, { let slf = py.from_borrowed_ptr::<$crate::PyCell>(slf); - let borrow = ::try_from_pycell(slf) - .map_err(|e| e.into())?; + let borrow = + >::try_from_pycell(slf) + .map_err(|e| e.into())?; $class::$f(borrow).into()$(.map($conv))? }) diff --git a/src/class/pyasync.rs b/src/class/pyasync.rs index 71f569385cc..ce824bea284 100644 --- a/src/class/pyasync.rs +++ b/src/class/pyasync.rs @@ -8,6 +8,7 @@ //! [PEP-0492](https://www.python.org/dev/peps/pep-0492/) //! +use crate::derive_utils::TryFromPyCell; use crate::err::PyResult; use crate::{ffi, PyClass, PyObject}; @@ -16,21 +17,21 @@ use crate::{ffi, PyClass, PyObject}; /// Each method in this trait corresponds to Python async/await implementation. #[allow(unused_variables)] pub trait PyAsyncProtocol<'p>: PyClass { - fn __await__(&'p self) -> Self::Result + fn __await__(slf: Self::Receiver) -> Self::Result where Self: PyAsyncAwaitProtocol<'p>, { unimplemented!() } - fn __aiter__(&'p self) -> Self::Result + fn __aiter__(slf: Self::Receiver) -> Self::Result where Self: PyAsyncAiterProtocol<'p>, { unimplemented!() } - fn __anext__(&'p mut self) -> Self::Result + fn __anext__(slf: Self::Receiver) -> Self::Result where Self: PyAsyncAnextProtocol<'p>, { @@ -58,16 +59,19 @@ pub trait PyAsyncProtocol<'p>: PyClass { } pub trait PyAsyncAwaitProtocol<'p>: PyAsyncProtocol<'p> { + type Receiver: TryFromPyCell<'p, Self>; type Success: crate::IntoPy; type Result: Into>; } pub trait PyAsyncAiterProtocol<'p>: PyAsyncProtocol<'p> { + type Receiver: TryFromPyCell<'p, Self>; type Success: crate::IntoPy; type Result: Into>; } pub trait PyAsyncAnextProtocol<'p>: PyAsyncProtocol<'p> { + type Receiver: TryFromPyCell<'p, Self>; type Success: crate::IntoPy; type Result: Into>>; } @@ -90,13 +94,13 @@ impl ffi::PyAsyncMethods { where T: for<'p> PyAsyncAwaitProtocol<'p>, { - self.am_await = py_unary_func!(PyAsyncAwaitProtocol, T::__await__); + self.am_await = py_unarys_func!(PyAsyncAwaitProtocol, T::__await__); } pub fn set_aiter(&mut self) where T: for<'p> PyAsyncAiterProtocol<'p>, { - self.am_aiter = py_unary_func!(PyAsyncAiterProtocol, T::__aiter__); + self.am_aiter = py_unarys_func!(PyAsyncAiterProtocol, T::__aiter__); } pub fn set_anext(&mut self) where @@ -123,7 +127,9 @@ mod anext { fn convert(self, py: Python) -> PyResult<*mut ffi::PyObject> { match self.0 { Some(val) => Ok(val.into_py(py).into_ptr()), - None => Err(crate::exceptions::StopAsyncIteration::py_err(())), + None => Err(crate::exceptions::StopAsyncIteration::py_err( + "Task Completed", + )), } } } @@ -133,12 +139,6 @@ mod anext { where T: for<'p> PyAsyncAnextProtocol<'p>, { - py_unary_func!( - PyAsyncAnextProtocol, - T::__anext__, - call_mut, - *mut crate::ffi::PyObject, - IterANextOutput - ) + py_unarys_func!(PyAsyncAnextProtocol, T::__anext__, IterANextOutput) } } diff --git a/tests/test_dunder.rs b/tests/test_dunder.rs index 421873f0f4b..a2fa3642f7d 100644 --- a/tests/test_dunder.rs +++ b/tests/test_dunder.rs @@ -1,5 +1,6 @@ use pyo3::class::{ - PyContextProtocol, PyIterProtocol, PyMappingProtocol, PyObjectProtocol, PySequenceProtocol, + PyAsyncProtocol, PyContextProtocol, PyIterProtocol, PyMappingProtocol, PyObjectProtocol, + PySequenceProtocol, }; use pyo3::exceptions::{IndexError, ValueError}; use pyo3::prelude::*; @@ -552,3 +553,63 @@ fn getattr_doesnt_override_member() { py_assert!(py, inst, "inst.data == 4"); py_assert!(py, inst, "inst.a == 8"); } + +/// Wraps a Python future and yield it once. +#[pyclass] +struct OnceFuture { + future: PyObject, + polled: bool, +} + +#[pymethods] +impl OnceFuture { + #[new] + fn new(future: PyObject) -> Self { + OnceFuture { + future, + polled: false, + } + } +} + +#[pyproto] +impl PyAsyncProtocol for OnceFuture { + fn __await__(slf: PyRef) -> PyResult> { + Ok(slf.into()) + } +} + +#[pyproto] +impl PyIterProtocol for OnceFuture { + fn __iter__(slf: PyRef) -> PyResult> { + Ok(slf.into()) + } + fn __next__(mut slf: PyRefMut) -> PyResult> { + if !slf.polled { + slf.polled = true; + Ok(Some(slf.future.clone())) + } else { + Ok(None) + } + } +} + +#[test] +fn test_await() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let once = py.get_type::(); + let source = pyo3::indoc::indoc!( + r#" +import asyncio +async def main(): + res = await Once(await asyncio.sleep(0.1)) + return res +loop = asyncio.get_event_loop() +assert loop.run_until_complete(main()) is None +loop.close() +"# + ); + let globals = [("Once", once)].into_py_dict(py); + py.run(source, Some(globals), None).unwrap(); +}