-
Notifications
You must be signed in to change notification settings - Fork 783
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support anyio with a Cargo feature
Asyncio is the standard and de facto main Python async runtime. Among non-standard runtime, only trio seems to have substantial traction, especially thanks to the anyio project. There is indeed a strong trend for anyio (e.g. FastApi), which can justify a dedicated support.
- Loading branch information
Showing
13 changed files
with
240 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Support anyio with a Cargo feature |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
use std::{task::Poll, thread, time::Duration}; | ||
|
||
use futures::{channel::oneshot, future::poll_fn}; | ||
use pyo3::prelude::*; | ||
|
||
#[pyfunction] | ||
async fn sleep(seconds: f64, result: Option<PyObject>) -> Option<PyObject> { | ||
if seconds <= 0.0 { | ||
let mut ready = false; | ||
poll_fn(|cx| { | ||
if ready { | ||
return Poll::Ready(()); | ||
} | ||
ready = true; | ||
cx.waker().wake_by_ref(); | ||
Poll::Pending | ||
}) | ||
.await; | ||
} else { | ||
let (tx, rx) = oneshot::channel(); | ||
thread::spawn(move || { | ||
thread::sleep(Duration::from_secs_f64(seconds)); | ||
tx.send(()).unwrap(); | ||
}); | ||
rx.await.unwrap(); | ||
} | ||
result | ||
} | ||
|
||
#[pymodule] | ||
pub fn anyio(_py: Python<'_>, m: &PyModule) -> PyResult<()> { | ||
m.add_function(wrap_pyfunction!(sleep, m)?)?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import asyncio | ||
|
||
from pyo3_pytests.anyio import sleep | ||
import trio | ||
|
||
|
||
def test_asyncio(): | ||
assert asyncio.run(sleep(0)) is None | ||
assert asyncio.run(sleep(0.1, 42)) == 42 | ||
|
||
|
||
def test_trio(): | ||
assert trio.run(sleep, 0) is None | ||
assert trio.run(sleep, 0.1, 42) == 42 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
//! Coroutine implementation using sniffio to select the appropriate implementation, | ||
//! compatible with anyio. | ||
use crate::{ | ||
coroutine::{asyncio::AsyncioWaker, trio::TrioWaker}, | ||
exceptions::PyRuntimeError, | ||
sync::GILOnceCell, | ||
types::PyAnyMethods, | ||
Bound, PyAny, PyErr, PyObject, PyResult, Python, | ||
}; | ||
|
||
fn current_async_library(py: Python<'_>) -> PyResult<Bound<'_, PyAny>> { | ||
static CURRENT_ASYNC_LIBRARY: GILOnceCell<PyObject> = GILOnceCell::new(); | ||
let import = || -> PyResult<_> { | ||
let module = py.import_bound("sniffio")?; | ||
Ok(module.getattr("current_async_library")?.into()) | ||
}; | ||
CURRENT_ASYNC_LIBRARY | ||
.get_or_try_init(py, import)? | ||
.bind(py) | ||
.call0() | ||
} | ||
|
||
fn unsupported(runtime: &str) -> PyErr { | ||
PyRuntimeError::new_err(format!("unsupported runtime {rt}", rt = runtime)) | ||
} | ||
|
||
/// Sniffio/anyio-compatible coroutine waker. | ||
/// | ||
/// Polling a Rust future calls `sniffio.current_async_library` to select the appropriate | ||
/// implementation, either asyncio or trio. | ||
pub(super) enum AnyioWaker { | ||
/// [`AsyncioWaker`] | ||
Asyncio(AsyncioWaker), | ||
/// [`TrioWaker`] | ||
Trio(TrioWaker), | ||
} | ||
|
||
impl AnyioWaker { | ||
pub(super) fn new(py: Python<'_>) -> PyResult<Self> { | ||
let sniffed = current_async_library(py)?; | ||
match sniffed.extract()? { | ||
"asyncio" => Ok(Self::Asyncio(AsyncioWaker::new(py)?)), | ||
"trio" => Ok(Self::Trio(TrioWaker::new(py)?)), | ||
rt => Err(unsupported(rt)), | ||
} | ||
} | ||
|
||
pub(super) fn yield_(&self, py: Python<'_>) -> PyResult<PyObject> { | ||
match self { | ||
AnyioWaker::Asyncio(w) => w.yield_(py), | ||
AnyioWaker::Trio(w) => w.yield_(py), | ||
} | ||
} | ||
|
||
pub(super) fn yield_waken(py: Python<'_>) -> PyResult<PyObject> { | ||
let sniffed = current_async_library(py)?; | ||
match sniffed.extract()? { | ||
"asyncio" => AsyncioWaker::yield_waken(py), | ||
"trio" => TrioWaker::yield_waken(py), | ||
rt => Err(unsupported(rt)), | ||
} | ||
} | ||
|
||
pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> { | ||
match self { | ||
AnyioWaker::Asyncio(w) => w.wake(py), | ||
AnyioWaker::Trio(w) => w.wake(py), | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
//! Coroutine implementation compatible with trio. | ||
use pyo3_macros::pyfunction; | ||
|
||
use crate::{ | ||
intern, | ||
sync::GILOnceCell, | ||
types::{PyAnyMethods, PyCFunction, PyIterator}, | ||
wrap_pyfunction_bound, Bound, Py, PyAny, PyObject, PyResult, Python, | ||
}; | ||
|
||
struct Trio { | ||
cancel_shielded_checkpoint: PyObject, | ||
current_task: PyObject, | ||
current_trio_token: PyObject, | ||
reschedule: PyObject, | ||
succeeded: PyObject, | ||
wait_task_rescheduled: PyObject, | ||
} | ||
impl Trio { | ||
fn get(py: Python<'_>) -> PyResult<&Self> { | ||
static TRIO: GILOnceCell<Trio> = GILOnceCell::new(); | ||
TRIO.get_or_try_init(py, || { | ||
let module = py.import_bound("trio.lowlevel")?; | ||
Ok(Self { | ||
cancel_shielded_checkpoint: module.getattr("cancel_shielded_checkpoint")?.into(), | ||
current_task: module.getattr("current_task")?.into(), | ||
current_trio_token: module.getattr("current_trio_token")?.into(), | ||
reschedule: module.getattr("reschedule")?.into(), | ||
succeeded: module.getattr("Abort")?.getattr("SUCCEEDED")?.into(), | ||
wait_task_rescheduled: module.getattr("wait_task_rescheduled")?.into(), | ||
}) | ||
}) | ||
} | ||
} | ||
|
||
fn yield_from(coro_func: &PyAny) -> PyResult<PyObject> { | ||
PyIterator::from_object(coro_func.call_method0("__await__")?)? | ||
.next() | ||
.expect("cancel_shielded_checkpoint didn't yield") | ||
.map(Into::into) | ||
} | ||
|
||
/// Asyncio-compatible coroutine waker. | ||
/// | ||
/// Polling a Rust future yields `trio.lowlevel.wait_task_rescheduled()`, while `Waker::wake` | ||
/// reschedule the current task. | ||
pub(super) struct TrioWaker { | ||
task: PyObject, | ||
token: PyObject, | ||
} | ||
|
||
impl TrioWaker { | ||
pub(super) fn new(py: Python<'_>) -> PyResult<Self> { | ||
let trio = Trio::get(py)?; | ||
let task = trio.current_task.call0(py)?; | ||
let token = trio.current_trio_token.call0(py)?; | ||
Ok(Self { task, token }) | ||
} | ||
|
||
pub(super) fn yield_(&self, py: Python<'_>) -> PyResult<PyObject> { | ||
static ABORT_FUNC: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new(); | ||
let abort_func = ABORT_FUNC.get_or_try_init(py, || { | ||
wrap_pyfunction_bound!(abort_func, py).map(Into::into) | ||
})?; | ||
let wait_task_rescheduled = Trio::get(py)? | ||
.wait_task_rescheduled | ||
.call1(py, (abort_func,))?; | ||
yield_from(wait_task_rescheduled.as_ref(py)) | ||
} | ||
|
||
pub(super) fn yield_waken(py: Python<'_>) -> PyResult<PyObject> { | ||
let checkpoint = Trio::get(py)?.cancel_shielded_checkpoint.call0(py)?; | ||
yield_from(checkpoint.as_ref(py)) | ||
} | ||
|
||
pub(super) fn wake(&self, py: Python<'_>) -> PyResult<()> { | ||
self.token.call_method1( | ||
py, | ||
intern!(py, "run_sync_soon"), | ||
(&Trio::get(py)?.reschedule, &self.task), | ||
)?; | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[pyfunction(crate = "crate")] | ||
fn abort_func(py: Python<'_>, _arg: &Bound<'_, PyAny>) -> PyResult<PyObject> { | ||
Ok(Trio::get(py)?.succeeded.clone()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters