Skip to content

Commit

Permalink
Enhance WSGI iterators (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro authored Jun 18, 2024
1 parent b9a92f9 commit 749bc24
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 53 deletions.
16 changes: 15 additions & 1 deletion granian/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@ def __call__(self, status: str, headers: List[Tuple[str, str]], exc_info: Any =
self.headers = headers


class ResponseIterWrap:
__slots__ = ['inner', 'iter']

def __init__(self, inner):
self.inner = inner
self.iter = iter(inner)

def __next__(self):
return self.iter.__next__()

def close(self):
self.inner.close()


def _callback_wrapper(callback: Callable[..., Any], scope_opts: Dict[str, Any], access_log_fmt=None):
basic_env: Dict[str, Any] = dict(os.environ)
basic_env.update(
Expand All @@ -44,7 +58,7 @@ def _runner(scope) -> Tuple[int, List[Tuple[str, str]], int, bytes]:
rv = b''.join(rv)
else:
resp_type = 1
rv = iter(rv)
rv = ResponseIterWrap(rv)

return (resp.status, resp.headers, resp_type, rv)

Expand Down
21 changes: 2 additions & 19 deletions src/conversion.rs
Original file line number Diff line number Diff line change
@@ -1,37 +1,20 @@
use pyo3::prelude::*;
use std::ops::{Deref, DerefMut};

use crate::workers::{HTTP1Config, HTTP2Config};

pub(crate) struct BytesToPy(pub hyper::body::Bytes);

impl Deref for BytesToPy {
type Target = hyper::body::Bytes;

#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl DerefMut for BytesToPy {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl IntoPy<PyObject> for BytesToPy {
#[inline]
fn into_py(self, py: Python) -> PyObject {
(&self[..]).into_py(py)
self.0.as_ref().into_py(py)
}
}

impl ToPyObject for BytesToPy {
#[inline]
fn to_object(&self, py: Python<'_>) -> PyObject {
(&self[..]).into_py(py)
self.0.as_ref().into_py(py)
}
}

Expand Down
10 changes: 3 additions & 7 deletions src/wsgi/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use futures::TryStreamExt;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::{
Expand Down Expand Up @@ -107,12 +106,9 @@ fn run_callback(
.map_err(|e| match e {})
.boxed()
}
WSGI_ITER_RESPONSE_BODY => {
let body = http_body_util::StreamBody::new(
WSGIResponseBodyIter::new(pybody).map_ok(|v| body::Frame::data(Bytes::from(v))),
);
BodyExt::boxed(BodyExt::map_err(body, |e| match e {}))
}
WSGI_ITER_RESPONSE_BODY => BodyExt::boxed(http_body_util::StreamBody::new(tokio_stream::iter(
WSGIResponseBodyIter::new(pybody),
))),
_ => empty_body(),
};
Ok((status, headers, body))
Expand Down
71 changes: 45 additions & 26 deletions src/wsgi/types.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
use futures::{Stream, StreamExt};
use futures::StreamExt;
use http_body_util::BodyExt;
use hyper::body::{self, Bytes};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyList};
use std::borrow::Cow;
use std::sync::{Arc, Mutex};
use std::{
borrow::Cow,
convert::Infallible,
task::{Context, Poll},
};
use tokio::sync::Mutex as AsyncMutex;
use tokio_util::bytes::BytesMut;
use tokio_util::bytes::{BufMut, BytesMut};

use crate::conversion::BytesToPy;
use crate::runtime::RuntimeRef;
Expand Down Expand Up @@ -39,19 +35,26 @@ impl WSGIBody {
}
}

#[allow(clippy::await_holding_lock)]
async fn fill_buffer(
stream: Arc<AsyncMutex<http_body_util::BodyStream<body::Incoming>>>,
buffer: Arc<Mutex<BytesMut>>,
buffering: WSGIBodyBuffering,
) {
let mut buffer = buffer.lock().unwrap();
if let WSGIBodyBuffering::Size(size) = buffering {
if buffer.len() >= size {
return;
}
}

let mut stream = stream.lock().await;
loop {
if let Some(chunk) = stream.next().await {
let data = chunk
.map(|buf| buf.into_data().unwrap_or_default())
.unwrap_or(Bytes::new());
let mut buffer = buffer.lock().unwrap();
buffer.extend_from_slice(data.as_ref());
buffer.put(data);
match buffering {
WSGIBodyBuffering::Line => {
if !buffer.contains(&LINE_SPLIT) {
Expand All @@ -72,21 +75,20 @@ impl WSGIBody {
#[allow(clippy::map_unwrap_or)]
fn _readline(&self, py: Python) -> Bytes {
let inner = self.inner.clone();
let buffer = self.buffer.clone();
py.allow_threads(|| {
self.rt.inner.block_on(async move {
WSGIBody::fill_buffer(inner, buffer, WSGIBodyBuffering::Line).await;
WSGIBody::fill_buffer(inner, self.buffer.clone(), WSGIBodyBuffering::Line).await;
});
});

let mut buffer = self.buffer.lock().unwrap();
buffer
.iter()
.position(|&c| c == LINE_SPLIT)
.map(|next_split| buffer.split_to(next_split).into())
.map(|next_split| buffer.split_to(next_split).freeze())
.unwrap_or_else(|| {
let len = buffer.len();
buffer.split_to(len).into()
buffer.split_to(len).freeze()
})
}
}
Expand Down Expand Up @@ -124,22 +126,24 @@ impl WSGIBody {
0 => BytesToPy(Bytes::new()),
size => {
let inner = self.inner.clone();
let buffer = self.buffer.clone();
py.allow_threads(|| {
self.rt.inner.block_on(async move {
WSGIBody::fill_buffer(inner, buffer, WSGIBodyBuffering::Size(size)).await;
WSGIBody::fill_buffer(inner, self.buffer.clone(), WSGIBodyBuffering::Size(size)).await;
});
});

let mut buffer = self.buffer.lock().unwrap();
let limit = buffer.len();
let rsize = if size > limit { limit } else { size };
BytesToPy(buffer.split_to(rsize).into())
let data = buffer.split_to(rsize).freeze();
BytesToPy(data)
}
},
}
}

fn readline(&self, py: Python) -> BytesToPy {
#[pyo3(signature = (_size=None))]
fn readline(&self, py: Python, _size: Option<usize>) -> BytesToPy {
BytesToPy(self._readline(py))
}

Expand All @@ -164,26 +168,42 @@ impl WSGIBody {

pub(crate) struct WSGIResponseBodyIter {
inner: PyObject,
closed: bool,
}

impl WSGIResponseBodyIter {
pub fn new(body: PyObject) -> Self {
Self { inner: body }
Self {
inner: body,
closed: false,
}
}

#[inline]
fn close_inner(&self, py: Python) {
fn close_inner(&mut self, py: Python) {
let _ = self.inner.call_method0(py, pyo3::intern!(py, "close"));
self.closed = true;
}
}

impl Stream for WSGIResponseBodyIter {
type Item = Result<Box<[u8]>, Infallible>;
impl Drop for WSGIResponseBodyIter {
fn drop(&mut self) {
if !self.closed {
Python::with_gil(|py| self.close_inner(py));
}
}
}

fn poll_next(self: std::pin::Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let ret = Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) {
impl Iterator for WSGIResponseBodyIter {
type Item = Result<body::Frame<Bytes>, anyhow::Error>;

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| match self.inner.call_method0(py, pyo3::intern!(py, "__next__")) {
Ok(chunk_obj) => match chunk_obj.extract::<Cow<[u8]>>(py) {
Ok(chunk) => Some(Ok(chunk.into())),
Ok(chunk) => {
let chunk: Box<[u8]> = chunk.into();
Some(Ok(body::Frame::data(Bytes::from(chunk))))
}
_ => {
self.close_inner(py);
None
Expand All @@ -196,7 +216,6 @@ impl Stream for WSGIResponseBodyIter {
self.close_inner(py);
None
}
});
Poll::Ready(ret)
})
}
}

0 comments on commit 749bc24

Please sign in to comment.