Skip to content

Commit

Permalink
Keep asyncio task ref when running callbacks (#441)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro authored Nov 22, 2024
1 parent 9ce9551 commit 9c7b33c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/asgi/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use pyo3::prelude::*;
use pyo3::types::PyDict;
use std::{net::SocketAddr, sync::Arc};
use std::{
net::SocketAddr,
sync::{Arc, Mutex},
};
use tokio::sync::oneshot;

use super::{
Expand Down Expand Up @@ -117,6 +120,7 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
cb: PyObject,
#[pyo3(get)]
scope: PyObject,
pytaskref: Arc<Mutex<Option<PyObject>>>,
}

impl CallbackWrappedRunnerHTTP {
Expand All @@ -126,6 +130,7 @@ impl CallbackWrappedRunnerHTTP {
context: cb.context,
cb: cb.callback.clone_ref(py),
scope: scope.into_py(py),
pytaskref: Arc::new(Mutex::new(None)),
}
}

Expand All @@ -140,10 +145,12 @@ impl CallbackWrappedRunnerHTTP {

fn done(&self) {
callback_impl_done_http!(self);
self.pytaskref.lock().unwrap().take();
}

fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
self.pytaskref.lock().unwrap().take();
}
}

Expand Down Expand Up @@ -238,6 +245,7 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
cb: PyObject,
#[pyo3(get)]
scope: PyObject,
pytaskref: Arc<Mutex<Option<PyObject>>>,
}

impl CallbackWrappedRunnerWebsocket {
Expand All @@ -247,6 +255,7 @@ impl CallbackWrappedRunnerWebsocket {
context: cb.context,
cb: cb.callback.clone_ref(py),
scope: scope.into_py(py),
pytaskref: Arc::new(Mutex::new(None)),
}
}

Expand All @@ -261,10 +270,12 @@ impl CallbackWrappedRunnerWebsocket {

fn done(&self) {
callback_impl_done_ws!(self);
self.pytaskref.lock().unwrap().take();
}

fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
self.pytaskref.lock().unwrap().take();
}
}

Expand Down
5 changes: 5 additions & 0 deletions src/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,10 +300,15 @@ macro_rules! callback_impl_run {
macro_rules! callback_impl_run_pytask {
() => {
pub fn run(self, py: Python<'_>) -> PyResult<Bound<PyAny>> {
let taskref = self.pytaskref.clone();
let event_loop = self.context.event_loop(py);
let context = self.context.context(py);
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
let kwctx = pyo3::types::PyDict::new_bound(py);
{
let mut taskref_guard = taskref.lock().unwrap();
*taskref_guard = Some(target.clone_ref(py));
}
kwctx.set_item(pyo3::intern!(py, "context"), context)?;
event_loop.call_method(pyo3::intern!(py, "call_soon_threadsafe"), (target,), Some(&kwctx))
}
Expand Down
9 changes: 9 additions & 0 deletions src/rsgi/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use pyo3::prelude::*;
use std::sync::{Arc, Mutex};
use tokio::sync::oneshot;

use super::{
Expand Down Expand Up @@ -114,6 +115,7 @@ pub(crate) struct CallbackWrappedRunnerHTTP {
cb: PyObject,
#[pyo3(get)]
scope: PyObject,
pytaskref: Arc<Mutex<Option<PyObject>>>,
}

impl CallbackWrappedRunnerHTTP {
Expand All @@ -123,6 +125,7 @@ impl CallbackWrappedRunnerHTTP {
context: cb.context,
cb: cb.callback.clone_ref(py),
scope: scope.into_py(py),
pytaskref: Arc::new(Mutex::new(None)),
}
}

Expand All @@ -137,10 +140,12 @@ impl CallbackWrappedRunnerHTTP {

fn done(&self) {
callback_impl_done_http!(self);
self.pytaskref.lock().unwrap().take();
}

fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
self.pytaskref.lock().unwrap().take();
}
}

Expand Down Expand Up @@ -233,6 +238,7 @@ pub(crate) struct CallbackWrappedRunnerWebsocket {
cb: PyObject,
#[pyo3(get)]
scope: PyObject,
pytaskref: Arc<Mutex<Option<PyObject>>>,
}

impl CallbackWrappedRunnerWebsocket {
Expand All @@ -242,6 +248,7 @@ impl CallbackWrappedRunnerWebsocket {
context: cb.context,
cb: cb.callback.clone_ref(py),
scope: scope.into_py(py),
pytaskref: Arc::new(Mutex::new(None)),
}
}

Expand All @@ -256,10 +263,12 @@ impl CallbackWrappedRunnerWebsocket {

fn done(&self) {
callback_impl_done_ws!(self);
self.pytaskref.lock().unwrap().take();
}

fn err(&self, err: Bound<PyAny>) {
callback_impl_done_err!(self, &PyErr::from_value_bound(err));
self.pytaskref.lock().unwrap().take();
}
}

Expand Down

0 comments on commit 9c7b33c

Please sign in to comment.