diff --git a/src/body/body.rs b/src/body/body.rs index 3308d3b3bd..939b4f5689 100644 --- a/src/body/body.rs +++ b/src/body/body.rs @@ -11,7 +11,7 @@ use futures_util::TryStreamExt; use http::HeaderMap; use http_body::{Body as HttpBody, SizeHint}; -use crate::common::{task, Future, Never, Pin, Poll}; +use crate::common::{task, watch, Future, Never, Pin, Poll}; use crate::proto::DecodedLength; use crate::upgrade::OnUpgrade; @@ -33,7 +33,7 @@ enum Kind { Once(Option), Chan { content_length: DecodedLength, - abort_rx: oneshot::Receiver<()>, + want_tx: watch::Sender, rx: mpsc::Receiver>, }, H2 { @@ -79,12 +79,14 @@ enum DelayEof { /// Useful when wanting to stream chunks from another thread. See /// [`Body::channel`](Body::channel) for more. #[must_use = "Sender does nothing unless sent on"] -#[derive(Debug)] pub struct Sender { - abort_tx: oneshot::Sender<()>, + want_rx: watch::Receiver, tx: BodySender, } +const WANT_PENDING: usize = 1; +const WANT_READY: usize = 2; + impl Body { /// Create an empty `Body` stream. /// @@ -106,17 +108,22 @@ impl Body { /// Useful when wanting to stream chunks from another thread. #[inline] pub fn channel() -> (Sender, Body) { - Self::new_channel(DecodedLength::CHUNKED) + Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) } - pub(crate) fn new_channel(content_length: DecodedLength) -> (Sender, Body) { + pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) { let (tx, rx) = mpsc::channel(0); - let (abort_tx, abort_rx) = oneshot::channel(); - let tx = Sender { abort_tx, tx }; + // If wanter is true, `Sender::poll_ready()` won't becoming ready + // until the `Body` has been polled for data once. + let want = if wanter { WANT_PENDING } else { WANT_READY }; + + let (want_tx, want_rx) = watch::channel(want); + + let tx = Sender { want_rx, tx }; let rx = Body::new(Kind::Chan { content_length, - abort_rx, + want_tx, rx, }); @@ -236,11 +243,9 @@ impl Body { Kind::Chan { content_length: ref mut len, ref mut rx, - ref mut abort_rx, + ref mut want_tx, } => { - if let Poll::Ready(Ok(())) = Pin::new(abort_rx).poll(cx) { - return Poll::Ready(Some(Err(crate::Error::new_body_write_aborted()))); - } + want_tx.send(WANT_READY); match ready!(Pin::new(rx).poll_next(cx)?) { Some(chunk) => { @@ -460,19 +465,29 @@ impl From> for Body { impl Sender { /// Check to see if this `Sender` can send more data. pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - match self.abort_tx.poll_canceled(cx) { - Poll::Ready(()) => return Poll::Ready(Err(crate::Error::new_closed())), - Poll::Pending => (), // fallthrough - } - + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); self.tx .poll_ready(cx) .map_err(|_| crate::Error::new_closed()) } + fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> crate::Result<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + /// Send data on this channel when it is ready. pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> { - futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await?; + self.ready().await?; self.tx .try_send(Ok(chunk)) .map_err(|_| crate::Error::new_closed()) @@ -498,8 +513,11 @@ impl Sender { /// Aborts the body in an abnormal fashion. pub fn abort(self) { - // TODO(sean): this can just be `self.tx.clone().try_send()` - let _ = self.abort_tx.send(()); + let _ = self + .tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(crate::Error::new_body_write_aborted())); } pub(crate) fn send_error(&mut self, err: crate::Error) { @@ -507,11 +525,29 @@ impl Sender { } } +impl fmt::Debug for Sender { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[derive(Debug)] + struct Open; + #[derive(Debug)] + struct Closed; + + let mut builder = f.debug_tuple("Sender"); + match self.want_rx.peek() { + watch::CLOSED => builder.field(&Closed), + _ => builder.field(&Open), + }; + + builder.finish() + } +} + #[cfg(test)] mod tests { use std::mem; + use std::task::Poll; - use super::{Body, Sender}; + use super::{Body, DecodedLength, HttpBody, Sender}; #[test] fn test_size_of() { @@ -541,4 +577,97 @@ mod tests { "Option" ); } + + #[tokio::test] + async fn channel_abort() { + let (tx, mut rx) = Body::channel(); + + tx.abort(); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[tokio::test] + async fn channel_abort_when_buffer_is_full() { + let (mut tx, mut rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + // buffer is full, but can still send abort + tx.abort(); + + let chunk1 = rx.data().await.expect("item 1").expect("chunk 1"); + assert_eq!(chunk1, "chunk 1"); + + let err = rx.data().await.unwrap().unwrap_err(); + assert!(err.is_body_write_aborted(), "{:?}", err); + } + + #[test] + fn channel_buffers_one() { + let (mut tx, _rx) = Body::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + + // buffer is now full + let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); + assert_eq!(chunk2, "chunk 2"); + } + + #[tokio::test] + async fn channel_empty() { + let (_, mut rx) = Body::channel(); + + assert!(rx.data().await.is_none()); + } + + #[test] + fn channel_ready() { + let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); + } + + #[test] + fn channel_wanter() { + let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + let mut rx_data = tokio_test::task::spawn(rx.data()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + assert!(rx_data.poll().is_pending(), "poll rx.data"); + assert!(tx_ready.is_woken(), "rx poll wakes tx"); + + assert!( + tx_ready.poll().is_ready(), + "tx is ready after rx has been polled" + ); + } + + #[test] + fn channel_notices_closure() { + let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!( + tx_ready.poll().is_pending(), + "tx isn't ready before rx has been polled" + ); + + drop(rx); + assert!(tx_ready.is_woken(), "dropping rx wakes tx"); + + match tx_ready.poll() { + Poll::Ready(Err(ref e)) if e.is_closed() => (), + unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), + } + } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 394e549895..3716a56c67 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -14,6 +14,7 @@ pub(crate) mod io; mod lazy; mod never; pub(crate) mod task; +pub(crate) mod watch; pub use self::exec::Executor; pub(crate) use self::exec::{BoxSendFuture, Exec}; diff --git a/src/common/watch.rs b/src/common/watch.rs new file mode 100644 index 0000000000..ba17d551cb --- /dev/null +++ b/src/common/watch.rs @@ -0,0 +1,73 @@ +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! - The value `0` is reserved for closed. + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(crate) const CLOSED: usize = 0; + +pub(crate) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!( + initial != CLOSED, + "watch::channel initial state of 0 is reserved" + ); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + ( + Sender { + shared: shared.clone(), + }, + Receiver { shared }, + ) +} + +pub(crate) struct Sender { + shared: Arc, +} + +pub(crate) struct Receiver { + shared: Arc, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(crate) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 0575536827..8f3532b9d4 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -8,7 +8,7 @@ use http::{HeaderMap, Method, Version}; use tokio::io::{AsyncRead, AsyncWrite}; use super::io::Buffered; -use super::{/*Decode,*/ Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext,}; +use super::{Decoder, Encode, EncodedBuf, Encoder, Http1Transaction, ParseContext, Wants}; use crate::common::{task, Pin, Poll, Unpin}; use crate::headers::connection_keep_alive; use crate::proto::{BodyLength, DecodedLength, MessageHead}; @@ -114,7 +114,7 @@ where pub fn can_read_body(&self) -> bool { match self.state.reading { - Reading::Body(..) => true, + Reading::Body(..) | Reading::Continue(..) => true, _ => false, } } @@ -129,10 +129,10 @@ where read_buf.len() >= 24 && read_buf[..24] == *H2_PREFACE } - pub fn poll_read_head( + pub(super) fn poll_read_head( &mut self, cx: &mut task::Context<'_>, - ) -> Poll, DecodedLength, bool)>>> { + ) -> Poll, DecodedLength, Wants)>>> { debug_assert!(self.can_read_head()); trace!("Conn::read_head"); @@ -156,23 +156,28 @@ where self.state.keep_alive &= msg.keep_alive; self.state.version = msg.head.version; + let mut wants = if msg.wants_upgrade { + Wants::UPGRADE + } else { + Wants::EMPTY + }; + if msg.decode == DecodedLength::ZERO { - if log_enabled!(log::Level::Debug) && msg.expect_continue { + if msg.expect_continue { debug!("ignoring expect-continue since body is empty"); } self.state.reading = Reading::KeepAlive; if !T::should_read_first() { self.try_keep_alive(cx); } + } else if msg.expect_continue { + self.state.reading = Reading::Continue(Decoder::new(msg.decode)); + wants = wants.add(Wants::EXPECT); } else { - if msg.expect_continue { - let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; - self.io.headers_buf().extend_from_slice(cont); - } self.state.reading = Reading::Body(Decoder::new(msg.decode)); - }; + } - Poll::Ready(Some(Ok((msg.head, msg.decode, msg.wants_upgrade)))) + Poll::Ready(Some(Ok((msg.head, msg.decode, wants)))) } fn on_read_head_error(&mut self, e: crate::Error) -> Poll>> { @@ -239,7 +244,19 @@ where } } } - _ => unreachable!("read_body invalid state: {:?}", self.state.reading), + Reading::Continue(ref decoder) => { + // Write the 100 Continue if not already responded... + if let Writing::Init = self.state.writing { + trace!("automatically sending 100 Continue"); + let cont = b"HTTP/1.1 100 Continue\r\n\r\n"; + self.io.headers_buf().extend_from_slice(cont); + } + + // And now recurse once in the Reading::Body state... + self.state.reading = Reading::Body(decoder.clone()); + return self.poll_read_body(cx); + } + _ => unreachable!("poll_read_body invalid state: {:?}", self.state.reading), }; self.state.reading = reading; @@ -346,7 +363,9 @@ where // would finish. match self.state.reading { - Reading::Body(..) | Reading::KeepAlive | Reading::Closed => return, + Reading::Continue(..) | Reading::Body(..) | Reading::KeepAlive | Reading::Closed => { + return + } Reading::Init => (), }; @@ -711,6 +730,7 @@ struct State { #[derive(Debug)] enum Reading { Init, + Continue(Decoder), Body(Decoder), KeepAlive, Closed, diff --git a/src/proto/h1/dispatch.rs b/src/proto/h1/dispatch.rs index 07d05fca16..ff5bf01832 100644 --- a/src/proto/h1/dispatch.rs +++ b/src/proto/h1/dispatch.rs @@ -4,7 +4,7 @@ use bytes::{Buf, Bytes}; use http::{Request, Response, StatusCode}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::Http1Transaction; +use super::{Http1Transaction, Wants}; use crate::body::{Body, Payload}; use crate::common::{task, Future, Never, Pin, Poll, Unpin}; use crate::proto::{ @@ -235,16 +235,16 @@ where } // dispatch is ready for a message, try to read one match ready!(self.conn.poll_read_head(cx)) { - Some(Ok((head, body_len, wants_upgrade))) => { + Some(Ok((head, body_len, wants))) => { let mut body = match body_len { DecodedLength::ZERO => Body::empty(), other => { - let (tx, rx) = Body::new_channel(other); + let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT)); self.body_tx = Some(tx); rx } }; - if wants_upgrade { + if wants.contains(Wants::UPGRADE) { body.set_on_upgrade(self.conn.on_upgrade()); } self.dispatch.recv_msg(Ok((head, body)))?; diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 39efb8e7b8..2d0bf39bc9 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -74,3 +74,22 @@ pub(crate) struct Encode<'a, T> { req_method: &'a mut Option, title_case_headers: bool, } + +/// Extra flags that a request "wants", like expect-continue or upgrades. +#[derive(Clone, Copy, Debug)] +struct Wants(u8); + +impl Wants { + const EMPTY: Wants = Wants(0b00); + const EXPECT: Wants = Wants(0b01); + const UPGRADE: Wants = Wants(0b10); + + #[must_use] + fn add(self, other: Wants) -> Wants { + Wants(self.0 | other.0) + } + + fn contains(&self, other: Wants) -> bool { + (self.0 & other.0) == other.0 + } +} diff --git a/tests/server.rs b/tests/server.rs index 054dfc8aaf..59ef0f6fee 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -785,6 +785,57 @@ fn expect_continue_but_no_body_is_ignored() { assert_eq!(&resp[..expected.len()], expected); } +#[tokio::test] +async fn expect_continue_waits_for_body_poll() { + let _ = pretty_env_logger::try_init(); + let mut listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + + let child = thread::spawn(move || { + let mut tcp = connect(&addr); + + tcp.write_all( + b"\ + POST /foo HTTP/1.1\r\n\ + Host: example.domain\r\n\ + Expect: 100-continue\r\n\ + Content-Length: 100\r\n\ + Connection: Close\r\n\ + \r\n\ + ", + ) + .expect("write"); + + let expected = "HTTP/1.1 400 Bad Request\r\n"; + let mut resp = String::new(); + tcp.read_to_string(&mut resp).expect("read"); + + assert_eq!(&resp[..expected.len()], expected); + }); + + let (socket, _) = listener.accept().await.expect("accept"); + + Http::new() + .serve_connection( + socket, + service_fn(|req| { + assert_eq!(req.headers()["expect"], "100-continue"); + // But! We're never going to poll the body! + tokio::time::delay_for(Duration::from_millis(50)).map(move |_| { + // Move and drop the req, so we don't auto-close + drop(req); + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(hyper::Body::empty()) + }) + }), + ) + .await + .expect("serve_connection"); + + child.join().expect("client thread"); +} + #[test] fn pipeline_disabled() { let server = serve();