diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 84326101b57..a24d6bbd142 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -7,6 +7,15 @@ use pyo3::{prelude::*, py_run}; #[path = "../src/tests/common.rs"] mod common; +fn handle_windows(test: &str) -> String { + let set_event_loop_policy = r#" + import asyncio, sys + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + "#; + pyo3::unindent::unindent(set_event_loop_policy) + &pyo3::unindent::unindent(test) +} + #[test] fn noop_coroutine() { #[pyfunction] @@ -16,7 +25,7 @@ fn noop_coroutine() { Python::with_gil(|gil| { let noop = wrap_pyfunction!(noop, gil).unwrap(); let test = "import asyncio; assert asyncio.run(noop()) == 42"; - py_run!(gil, noop, test); + py_run!(gil, noop, &handle_windows(test)); }) } @@ -38,7 +47,7 @@ fn sleep_0_like_coroutine() { Python::with_gil(|gil| { let sleep_0 = wrap_pyfunction!(sleep_0, gil).unwrap(); let test = "import asyncio; assert asyncio.run(sleep_0()) == 42"; - py_run!(gil, sleep_0, test); + py_run!(gil, sleep_0, &handle_windows(test)); }) } @@ -56,13 +65,8 @@ async fn sleep(seconds: f64) -> usize { fn sleep_coroutine() { Python::with_gil(|gil| { let sleep = wrap_pyfunction!(sleep, gil).unwrap(); - let test = r#" - import asyncio, sys - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - assert asyncio.run(sleep(0.1)) == 42 - "#; - py_run!(gil, sleep, test); + let test = r#"import asyncio; assert asyncio.run(sleep(0.1)) == 42"#; + py_run!(gil, sleep, &handle_windows(test)); }) } @@ -71,9 +75,7 @@ fn cancelled_coroutine() { Python::with_gil(|gil| { let sleep = wrap_pyfunction!(sleep, gil).unwrap(); let test = r#" - import asyncio, sys - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + import asyncio async def main(): task = asyncio.create_task(sleep(1)) await asyncio.sleep(0) @@ -84,7 +86,11 @@ fn cancelled_coroutine() { let globals = gil.import("__main__").unwrap().dict(); globals.set_item("sleep", sleep).unwrap(); let err = gil - .run(&pyo3::unindent::unindent(test), Some(globals), None) + .run( + &pyo3::unindent::unindent(&handle_windows(test)), + Some(globals), + None, + ) .unwrap_err(); assert_eq!(err.value(gil).get_type().name().unwrap(), "CancelledError"); })