From 5442b6faddaff9aeb7c073031a3b7aa4497fda4d Mon Sep 17 00:00:00 2001 From: Anthony Ramine <123095+nox@users.noreply.github.com> Date: Mon, 24 May 2021 20:20:44 +0200 Subject: [PATCH] feat(http2): Implement Client-side CONNECT support over HTTP/2 (#2523) Closes #2508 --- Cargo.toml | 2 +- src/body/length.rs | 11 ++ src/client/client.rs | 7 +- src/error.rs | 6 +- src/proto/h2/client.rs | 121 +++++++++++---- src/proto/h2/mod.rs | 208 ++++++++++++++++++++++--- src/proto/h2/server.rs | 88 +++++++++-- src/upgrade.rs | 9 +- tests/client.rs | 123 ++++++++++++++- tests/server.rs | 336 +++++++++++++++++++++++++++++++++++++++++ 10 files changed, 833 insertions(+), 78 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4a94f06aa9..93624a1ca6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ http = "0.2" http-body = "0.4" httpdate = "1.0" httparse = "1.4" -h2 = { version = "0.3", optional = true } +h2 = { version = "0.3.3", optional = true } itoa = "0.4.1" tracing = { version = "0.1", default-features = false, features = ["std"] } pin-project = "1.0" diff --git a/src/body/length.rs b/src/body/length.rs index aa9cf3dcd5..633a911fb2 100644 --- a/src/body/length.rs +++ b/src/body/length.rs @@ -3,6 +3,17 @@ use std::fmt; #[derive(Clone, Copy, PartialEq, Eq)] pub(crate) struct DecodedLength(u64); +#[cfg(any(feature = "http1", feature = "http2"))] +impl From> for DecodedLength { + fn from(len: Option) -> Self { + len.and_then(|len| { + // If the length is u64::MAX, oh well, just reported chunked. + Self::checked_new(len).ok() + }) + .unwrap_or(DecodedLength::CHUNKED) + } +} + #[cfg(any(feature = "http1", feature = "http2", test))] const MAX_LEN: u64 = std::u64::MAX - 2; diff --git a/src/client/client.rs b/src/client/client.rs index 3b6a0d9f31..a5d8dcfaf7 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -254,12 +254,9 @@ where absolute_form(req.uri_mut()); } else { origin_form(req.uri_mut()); - }; + } } else if req.method() == Method::CONNECT { - debug!("client does not support CONNECT requests over HTTP2"); - return Err(ClientError::Normal( - crate::Error::new_user_unsupported_request_method(), - )); + authority_form(req.uri_mut()); } let fut = pooled diff --git a/src/error.rs b/src/error.rs index 6912cd3a70..dd577b99a6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -90,7 +90,7 @@ pub(super) enum User { /// User tried to send a certain header in an unexpected context. /// /// For example, sending both `content-length` and `transfer-encoding`. - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] UnexpectedHeader, /// User tried to create a Request with bad version. @@ -290,7 +290,7 @@ impl Error { Error::new(Kind::User(user)) } - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] pub(super) fn new_user_header() -> Error { Error::new_user(User::UnexpectedHeader) @@ -405,7 +405,7 @@ impl Error { Kind::User(User::MakeService) => "error from user's MakeService", #[cfg(any(feature = "http1", feature = "http2"))] Kind::User(User::Service) => "error from user's Service", - #[cfg(feature = "http1")] + #[cfg(any(feature = "http1", feature = "http2"))] #[cfg(feature = "server")] Kind::User(User::UnexpectedHeader) => "user sent unexpected header", #[cfg(any(feature = "http1", feature = "http2"))] diff --git a/src/proto/h2/client.rs b/src/proto/h2/client.rs index 6d310e94c9..3692a8f253 100644 --- a/src/proto/h2/client.rs +++ b/src/proto/h2/client.rs @@ -2,17 +2,21 @@ use std::error::Error as StdError; #[cfg(feature = "runtime")] use std::time::Duration; +use bytes::Bytes; use futures_channel::{mpsc, oneshot}; use futures_util::future::{self, Either, FutureExt as _, TryFutureExt as _}; use futures_util::stream::StreamExt as _; use h2::client::{Builder, SendRequest}; +use http::{Method, StatusCode}; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use super::{ping, H2Upgraded, PipeToSendStream, SendBuf}; use crate::body::HttpBody; use crate::common::{exec::Exec, task, Future, Never, Pin, Poll}; use crate::headers; +use crate::proto::h2::UpgradedSendStream; use crate::proto::Dispatched; +use crate::upgrade::Upgraded; use crate::{Body, Request, Response}; type ClientRx = crate::client::dispatch::Receiver, Response>; @@ -233,8 +237,25 @@ where headers::set_content_length_if_missing(req.headers_mut(), len); } } + + let is_connect = req.method() == Method::CONNECT; let eos = body.is_end_stream(); - let (fut, body_tx) = match self.h2_tx.send_request(req, eos) { + let ping = self.ping.clone(); + + if is_connect { + if headers::content_length_parse_all(req.headers()) + .map_or(false, |len| len != 0) + { + warn!("h2 connect request with non-zero body not supported"); + cb.send(Err(( + crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), + None, + ))); + continue; + } + } + + let (fut, body_tx) = match self.h2_tx.send_request(req, !is_connect && eos) { Ok(ok) => ok, Err(err) => { debug!("client send request error: {}", err); @@ -243,45 +264,81 @@ where } }; - let ping = self.ping.clone(); - if !eos { - let mut pipe = Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| { - if let Err(e) = res { - debug!("client request body error: {}", e); - } - }); - - // eagerly see if the body pipe is ready and - // can thus skip allocating in the executor - match Pin::new(&mut pipe).poll(cx) { - Poll::Ready(_) => (), - Poll::Pending => { - let conn_drop_ref = self.conn_drop_ref.clone(); - // keep the ping recorder's knowledge of an - // "open stream" alive while this body is - // still sending... - let ping = ping.clone(); - let pipe = pipe.map(move |x| { - drop(conn_drop_ref); - drop(ping); - x + let send_stream = if !is_connect { + if !eos { + let mut pipe = + Box::pin(PipeToSendStream::new(body, body_tx)).map(|res| { + if let Err(e) = res { + debug!("client request body error: {}", e); + } }); - self.executor.execute(pipe); + + // eagerly see if the body pipe is ready and + // can thus skip allocating in the executor + match Pin::new(&mut pipe).poll(cx) { + Poll::Ready(_) => (), + Poll::Pending => { + let conn_drop_ref = self.conn_drop_ref.clone(); + // keep the ping recorder's knowledge of an + // "open stream" alive while this body is + // still sending... + let ping = ping.clone(); + let pipe = pipe.map(move |x| { + drop(conn_drop_ref); + drop(ping); + x + }); + self.executor.execute(pipe); + } } } - } + + None + } else { + Some(body_tx) + }; let fut = fut.map(move |result| match result { Ok(res) => { // record that we got the response headers ping.record_non_data(); - let content_length = decode_content_length(res.headers()); - let res = res.map(|stream| { - let ping = ping.for_stream(&stream); - crate::Body::h2(stream, content_length, ping) - }); - Ok(res) + let content_length = headers::content_length_parse_all(res.headers()); + if let (Some(mut send_stream), StatusCode::OK) = + (send_stream, res.status()) + { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect response with non-zero body not supported"); + + send_stream.send_reset(h2::Reason::INTERNAL_ERROR); + return Err(( + crate::Error::new_h2(h2::Reason::INTERNAL_ERROR.into()), + None, + )); + } + let (parts, recv_stream) = res.into_parts(); + let mut res = Response::from_parts(parts, Body::empty()); + + let (pending, on_upgrade) = crate::upgrade::pending(); + let io = H2Upgraded { + ping, + send_stream: unsafe { UpgradedSendStream::new(send_stream) }, + recv_stream, + buf: Bytes::new(), + }; + let upgraded = Upgraded::new(io, Bytes::new()); + + pending.fulfill(upgraded); + res.extensions_mut().insert(on_upgrade); + + Ok(res) + } else { + let res = res.map(|stream| { + let ping = ping.for_stream(&stream); + crate::Body::h2(stream, content_length.into(), ping) + }); + Ok(res) + } } Err(err) => { ping.ensure_not_timed_out().map_err(|e| (e, None))?; diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index cf06592903..0dbcc8d466 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -1,5 +1,5 @@ -use bytes::Buf; -use h2::SendStream; +use bytes::{Buf, Bytes}; +use h2::{RecvStream, SendStream}; use http::header::{ HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER, TRANSFER_ENCODING, UPGRADE, @@ -7,11 +7,14 @@ use http::header::{ use http::HeaderMap; use pin_project::pin_project; use std::error::Error as StdError; -use std::io::IoSlice; +use std::io::{self, Cursor, IoSlice}; +use std::mem; +use std::task::Context; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use crate::body::{DecodedLength, HttpBody}; +use crate::body::HttpBody; use crate::common::{task, Future, Pin, Poll}; -use crate::headers::content_length_parse_all; +use crate::proto::h2::ping::Recorder; pub(crate) mod ping; @@ -83,15 +86,6 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) { } } -fn decode_content_length(headers: &HeaderMap) -> DecodedLength { - if let Some(len) = content_length_parse_all(headers) { - // If the length is u64::MAX, oh well, just reported chunked. - DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED) - } else { - DecodedLength::CHUNKED - } -} - // body adapters used by both Client and Server #[pin_project] @@ -172,7 +166,7 @@ where is_eos, ); - let buf = SendBuf(Some(chunk)); + let buf = SendBuf::Buf(chunk); me.body_tx .send_data(buf, is_eos) .map_err(crate::Error::new_body_write)?; @@ -243,32 +237,202 @@ impl SendStreamExt for SendStream> { fn send_eos_frame(&mut self) -> crate::Result<()> { trace!("send body eos"); - self.send_data(SendBuf(None), true) + self.send_data(SendBuf::None, true) .map_err(crate::Error::new_body_write) } } -struct SendBuf(Option); +#[repr(usize)] +enum SendBuf { + Buf(B), + Cursor(Cursor>), + None, +} impl Buf for SendBuf { #[inline] fn remaining(&self) -> usize { - self.0.as_ref().map(|b| b.remaining()).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.remaining(), + Self::Cursor(ref c) => c.remaining(), + Self::None => 0, + } } #[inline] fn chunk(&self) -> &[u8] { - self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[]) + match *self { + Self::Buf(ref b) => b.chunk(), + Self::Cursor(ref c) => c.chunk(), + Self::None => &[], + } } #[inline] fn advance(&mut self, cnt: usize) { - if let Some(b) = self.0.as_mut() { - b.advance(cnt) + match *self { + Self::Buf(ref mut b) => b.advance(cnt), + Self::Cursor(ref mut c) => c.advance(cnt), + Self::None => {} } } fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { - self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0) + match *self { + Self::Buf(ref b) => b.chunks_vectored(dst), + Self::Cursor(ref c) => c.chunks_vectored(dst), + Self::None => 0, + } + } +} + +struct H2Upgraded +where + B: Buf, +{ + ping: Recorder, + send_stream: UpgradedSendStream, + recv_stream: RecvStream, + buf: Bytes, +} + +impl AsyncRead for H2Upgraded +where + B: Buf, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + read_buf: &mut ReadBuf<'_>, + ) -> Poll> { + if self.buf.is_empty() { + self.buf = loop { + match ready!(self.recv_stream.poll_data(cx)) { + None => return Poll::Ready(Ok(())), + Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => { + continue + } + Some(Ok(buf)) => { + self.ping.record_data(buf.len()); + break buf; + } + Some(Err(e)) => { + return Poll::Ready(Err(h2_to_io_error(e))); + } + } + }; + } + let cnt = std::cmp::min(self.buf.len(), read_buf.remaining()); + read_buf.put_slice(&self.buf[..cnt]); + self.buf.advance(cnt); + let _ = self.recv_stream.flow_control().release_capacity(cnt); + Poll::Ready(Ok(())) + } +} + +impl AsyncWrite for H2Upgraded +where + B: Buf, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if let Poll::Ready(reset) = self.send_stream.poll_reset(cx) { + return Poll::Ready(Err(h2_to_io_error(match reset { + Ok(reason) => reason.into(), + Err(e) => e, + }))); + } + if buf.is_empty() { + return Poll::Ready(Ok(0)); + } + self.send_stream.reserve_capacity(buf.len()); + Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) { + None => Ok(0), + Some(Ok(cnt)) => self.send_stream.write(&buf[..cnt], false).map(|()| cnt), + Some(Err(e)) => Err(h2_to_io_error(e)), + }) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(self.send_stream.write(&[], true)) + } +} + +fn h2_to_io_error(e: h2::Error) -> io::Error { + if e.is_io() { + e.into_io().unwrap() + } else { + io::Error::new(io::ErrorKind::Other, e) + } +} + +struct UpgradedSendStream(SendStream>>); + +impl UpgradedSendStream +where + B: Buf, +{ + unsafe fn new(inner: SendStream>) -> Self { + assert_eq!(mem::size_of::(), mem::size_of::>()); + Self(mem::transmute(inner)) + } + + fn reserve_capacity(&mut self, cnt: usize) { + unsafe { self.as_inner_unchecked().reserve_capacity(cnt) } + } + + fn poll_capacity(&mut self, cx: &mut Context<'_>) -> Poll>> { + unsafe { self.as_inner_unchecked().poll_capacity(cx) } + } + + fn poll_reset(&mut self, cx: &mut Context<'_>) -> Poll> { + unsafe { self.as_inner_unchecked().poll_reset(cx) } + } + + fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> { + let send_buf = SendBuf::Cursor(Cursor::new(buf.into())); + unsafe { + self.as_inner_unchecked() + .send_data(send_buf, end_of_stream) + .map_err(h2_to_io_error) + } + } + + unsafe fn as_inner_unchecked(&mut self) -> &mut SendStream> { + &mut *(&mut self.0 as *mut _ as *mut _) + } +} + +#[repr(transparent)] +struct Neutered { + _inner: B, + impossible: Impossible, +} + +enum Impossible {} + +unsafe impl Send for Neutered {} + +impl Buf for Neutered { + fn remaining(&self) -> usize { + match self.impossible {} + } + + fn chunk(&self) -> &[u8] { + match self.impossible {} + } + + fn advance(&mut self, _cnt: usize) { + match self.impossible {} } } diff --git a/src/proto/h2/server.rs b/src/proto/h2/server.rs index eea52e3e4b..de77eaa232 100644 --- a/src/proto/h2/server.rs +++ b/src/proto/h2/server.rs @@ -3,19 +3,24 @@ use std::marker::Unpin; #[cfg(feature = "runtime")] use std::time::Duration; +use bytes::Bytes; use h2::server::{Connection, Handshake, SendResponse}; -use h2::Reason; +use h2::{Reason, RecvStream}; +use http::{Method, Request}; use pin_project::pin_project; use tokio::io::{AsyncRead, AsyncWrite}; -use super::{decode_content_length, ping, PipeToSendStream, SendBuf}; +use super::{ping, PipeToSendStream, SendBuf}; use crate::body::HttpBody; use crate::common::exec::ConnStreamExec; use crate::common::{date, task, Future, Pin, Poll}; use crate::headers; +use crate::proto::h2::ping::Recorder; +use crate::proto::h2::{H2Upgraded, UpgradedSendStream}; use crate::proto::Dispatched; use crate::service::HttpService; +use crate::upgrade::{OnUpgrade, Pending, Upgraded}; use crate::{Body, Response}; // Our defaults are chosen for the "majority" case, which usually are not @@ -255,9 +260,9 @@ where // When the service is ready, accepts an incoming request. match ready!(self.conn.poll_accept(cx)) { - Some(Ok((req, respond))) => { + Some(Ok((req, mut respond))) => { trace!("incoming request"); - let content_length = decode_content_length(req.headers()); + let content_length = headers::content_length_parse_all(req.headers()); let ping = self .ping .as_ref() @@ -267,8 +272,36 @@ where // Record the headers received ping.record_non_data(); - let req = req.map(|stream| crate::Body::h2(stream, content_length, ping)); - let fut = H2Stream::new(service.call(req), respond); + let is_connect = req.method() == Method::CONNECT; + let (mut parts, stream) = req.into_parts(); + let (req, connect_parts) = if !is_connect { + ( + Request::from_parts( + parts, + crate::Body::h2(stream, content_length.into(), ping), + ), + None, + ) + } else { + if content_length.map_or(false, |len| len != 0) { + warn!("h2 connect request with non-zero body not supported"); + respond.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Ok(())); + } + let (pending, upgrade) = crate::upgrade::pending(); + debug_assert!(parts.extensions.get::().is_none()); + parts.extensions.insert(upgrade); + ( + Request::from_parts(parts, crate::Body::empty()), + Some(ConnectParts { + pending, + ping, + recv_stream: stream, + }), + ) + }; + + let fut = H2Stream::new(service.call(req), connect_parts, respond); exec.execute_h2stream(fut); } Some(Err(e)) => { @@ -331,18 +364,28 @@ enum H2StreamState where B: HttpBody, { - Service(#[pin] F), + Service(#[pin] F, Option), Body(#[pin] PipeToSendStream), } +struct ConnectParts { + pending: Pending, + ping: Recorder, + recv_stream: RecvStream, +} + impl H2Stream where B: HttpBody, { - fn new(fut: F, respond: SendResponse>) -> H2Stream { + fn new( + fut: F, + connect_parts: Option, + respond: SendResponse>, + ) -> H2Stream { H2Stream { reply: respond, - state: H2StreamState::Service(fut), + state: H2StreamState::Service(fut, connect_parts), } } } @@ -364,6 +407,7 @@ impl H2Stream where F: Future, E>>, B: HttpBody, + B::Data: 'static, B::Error: Into>, E: Into>, { @@ -371,7 +415,7 @@ where let mut me = self.project(); loop { let next = match me.state.as_mut().project() { - H2StreamStateProj::Service(h) => { + H2StreamStateProj::Service(h, connect_parts) => { let res = match h.poll(cx) { Poll::Ready(Ok(r)) => r, Poll::Pending => { @@ -402,6 +446,29 @@ where .entry(::http::header::DATE) .or_insert_with(date::update_and_header_value); + if let Some(connect_parts) = connect_parts.take() { + if res.status().is_success() { + if headers::content_length_parse_all(res.headers()) + .map_or(false, |len| len != 0) + { + warn!("h2 successful response to CONNECT request with body not supported"); + me.reply.send_reset(h2::Reason::INTERNAL_ERROR); + return Poll::Ready(Err(crate::Error::new_user_header())); + } + let send_stream = reply!(me, res, false); + connect_parts.pending.fulfill(Upgraded::new( + H2Upgraded { + ping: connect_parts.ping, + recv_stream: connect_parts.recv_stream, + send_stream: unsafe { UpgradedSendStream::new(send_stream) }, + buf: Bytes::new(), + }, + Bytes::new(), + )); + return Poll::Ready(Ok(())); + } + } + // automatically set Content-Length from body... if let Some(len) = body.size_hint().exact() { headers::set_content_length_if_missing(res.headers_mut(), len); @@ -428,6 +495,7 @@ impl Future for H2Stream where F: Future, E>>, B: HttpBody, + B::Data: 'static, B::Error: Into>, E: Into>, { diff --git a/src/upgrade.rs b/src/upgrade.rs index 6004c1a31a..efab10a6fc 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -62,12 +62,12 @@ pub fn on(msg: T) -> OnUpgrade { msg.on_upgrade() } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) struct Pending { tx: oneshot::Sender>, } -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] pub(super) fn pending() -> (Pending, OnUpgrade) { let (tx, rx) = oneshot::channel(); (Pending { tx }, OnUpgrade { rx: Some(rx) }) @@ -76,7 +76,7 @@ pub(super) fn pending() -> (Pending, OnUpgrade) { // ===== impl Upgraded ===== impl Upgraded { - #[cfg(any(feature = "http1", test))] + #[cfg(any(feature = "http1", feature = "http2", test))] pub(super) fn new(io: T, read_buf: Bytes) -> Self where T: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade { // ===== impl Pending ===== -#[cfg(feature = "http1")] +#[cfg(any(feature = "http1", feature = "http2"))] impl Pending { pub(super) fn fulfill(self, upgraded: Upgraded) { trace!("pending upgrade fulfill"); let _ = self.tx.send(Ok(upgraded)); } + #[cfg(feature = "http1")] /// Don't fulfill the pending Upgrade, but instead signal that /// upgrades are handled manually. pub(super) fn manual(self) { diff --git a/tests/client.rs b/tests/client.rs index 978f79a1d1..3eb6dd9015 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -2261,14 +2261,16 @@ mod conn { use std::thread; use std::time::Duration; + use bytes::Buf; use futures_channel::oneshot; use futures_util::future::{self, poll_fn, FutureExt, TryFutureExt}; use futures_util::StreamExt; + use hyper::upgrade::OnUpgrade; use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, ReadBuf}; use tokio::net::{TcpListener as TkTcpListener, TcpStream}; use hyper::client::conn; - use hyper::{self, Body, Method, Request}; + use hyper::{self, Body, Method, Request, Response, StatusCode}; use super::{concat, s, support, tcp_connect, FutureHyperExt}; @@ -2984,6 +2986,125 @@ mod conn { .expect("client should be open"); } + #[tokio::test] + async fn h2_connect() { + let _ = pretty_env_logger::try_init(); + + let listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + // Spawn an HTTP2 server that asks for bread and responds with baguette. + tokio::spawn(async move { + let sock = listener.accept().await.unwrap().0; + let mut h2 = h2::server::handshake(sock).await.unwrap(); + + let (req, mut respond) = h2.accept().await.unwrap().unwrap(); + tokio::spawn(async move { + poll_fn(|cx| h2.poll_closed(cx)).await.unwrap(); + }); + assert_eq!(req.method(), Method::CONNECT); + + let mut body = req.into_body(); + + let mut send_stream = respond.send_response(Response::default(), false).unwrap(); + + send_stream.send_data("Bread?".into(), true).unwrap(); + + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Baguette!"); + let _ = body.flow_control().release_capacity(bytes.len()); + + assert!(body.data().await.is_none()); + }); + + let io = tcp_connect(&addr).await.expect("tcp connect"); + let (mut client, conn) = conn::Builder::new() + .http2_only(true) + .handshake::<_, Body>(io) + .await + .expect("http handshake"); + + tokio::spawn(async move { + conn.await.expect("client conn shouldn't error"); + }); + + let req = Request::connect("localhost") + .body(hyper::Body::empty()) + .unwrap(); + let res = client.send_request(req).await.expect("send_request"); + assert_eq!(res.status(), StatusCode::OK); + + let mut upgraded = hyper::upgrade::on(res).await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Bread?"); + + upgraded.write_all(b"Baguette!").await.unwrap(); + + upgraded.shutdown().await.unwrap(); + } + + #[tokio::test] + async fn h2_connect_rejected() { + let _ = pretty_env_logger::try_init(); + + let listener = TkTcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + let (done_tx, done_rx) = oneshot::channel(); + + tokio::spawn(async move { + let sock = listener.accept().await.unwrap().0; + let mut h2 = h2::server::handshake(sock).await.unwrap(); + + let (req, mut respond) = h2.accept().await.unwrap().unwrap(); + tokio::spawn(async move { + poll_fn(|cx| h2.poll_closed(cx)).await.unwrap(); + }); + assert_eq!(req.method(), Method::CONNECT); + + let res = Response::builder().status(400).body(()).unwrap(); + let mut send_stream = respond.send_response(res, false).unwrap(); + send_stream + .send_data("No bread for you!".into(), true) + .unwrap(); + done_rx.await.unwrap(); + }); + + let io = tcp_connect(&addr).await.expect("tcp connect"); + let (mut client, conn) = conn::Builder::new() + .http2_only(true) + .handshake::<_, Body>(io) + .await + .expect("http handshake"); + + tokio::spawn(async move { + conn.await.expect("client conn shouldn't error"); + }); + + let req = Request::connect("localhost") + .body(hyper::Body::empty()) + .unwrap(); + let res = client.send_request(req).await.expect("send_request"); + assert_eq!(res.status(), StatusCode::BAD_REQUEST); + assert!(res.extensions().get::().is_none()); + + let mut body = String::new(); + hyper::body::aggregate(res.into_body()) + .await + .unwrap() + .reader() + .read_to_string(&mut body) + .unwrap(); + assert_eq!(body, "No bread for you!"); + + done_tx.send(()).unwrap(); + } + async fn drain_til_eof(mut sock: T) -> io::Result<()> { let mut buf = [0u8; 1024]; loop { diff --git a/tests/server.rs b/tests/server.rs index 662e903d57..297b09ac73 100644 --- a/tests/server.rs +++ b/tests/server.rs @@ -13,10 +13,13 @@ use std::task::{Context, Poll}; use std::thread; use std::time::Duration; +use bytes::Bytes; use futures_channel::oneshot; use futures_util::future::{self, Either, FutureExt, TryFutureExt}; #[cfg(feature = "stream")] use futures_util::stream::StreamExt as _; +use h2::client::SendRequest; +use h2::{RecvStream, SendStream}; use http::header::{HeaderName, HeaderValue}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::net::{TcpListener, TcpStream as TkTcpStream}; @@ -1482,6 +1485,339 @@ async fn http_connect_new() { assert_eq!(s(&vec), "bar=foo"); } +#[tokio::test] +async fn h2_connect() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + send_stream.send_data("Baguette!".into(), true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_multiplex() { + use futures_util::stream::FuturesUnordered; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + tokio::spawn(async move { + let mut streams = vec![]; + for i in 0..80 { + let request = Request::connect(format!("localhost_{}", i % 4)) + .body(()) + .unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + streams.push((i, response, send_stream)); + } + + let futures = streams + .into_iter() + .map(|(i, response, mut send_stream)| async move { + if i % 4 == 0 { + return; + } + + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + if i % 4 == 1 { + return; + } + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + if i % 4 == 2 { + return; + } + + send_stream.send_data("Baguette!".into(), true).unwrap(); + + assert!(body.data().await.unwrap().unwrap().is_empty()); + }) + .collect::>(); + + futures.for_each(future::ready).await; + }); + + let svc = service_fn(move |req: Request| { + let authority = req.uri().authority().unwrap().to_string(); + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let upgrade_res = on_upgrade.await; + if authority == "localhost_0" { + assert!(upgrade_res.expect_err("upgrade cancelled").is_canceled()); + return; + } + let mut upgraded = upgrade_res.expect("upgrade successful"); + + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + let read_res = upgraded.read_to_end(&mut vec).await; + + if authority == "localhost_1" || authority == "localhost_2" { + let err = read_res.expect_err("read failed"); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!( + err.get_ref() + .unwrap() + .downcast_ref::() + .unwrap() + .reason(), + Some(h2::Reason::CANCEL), + ); + return; + } + + read_res.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_large_body() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + const NO_BREAD: &str = "All work and no bread makes nox a dull boy.\n"; + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + let large_body = Bytes::from(NO_BREAD.repeat(9000)); + + send_stream.send_data(large_body.clone(), false).unwrap(); + send_stream.send_data(large_body, true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + if upgraded.read_to_end(&mut vec).await.is_err() { + return; + } + assert_eq!(vec.len(), NO_BREAD.len() * 9000 * 2); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + +#[tokio::test] +async fn h2_connect_empty_frames() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let _ = pretty_env_logger::try_init(); + let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); + let addr = listener.local_addr().unwrap(); + let conn = connect_async(addr).await; + + let (h2, connection) = h2::client::handshake(conn).await.unwrap(); + tokio::spawn(async move { + connection.await.unwrap(); + }); + let mut h2 = h2.ready().await.unwrap(); + + async fn connect_and_recv_bread( + h2: &mut SendRequest, + ) -> (RecvStream, SendStream) { + let request = Request::connect("localhost").body(()).unwrap(); + let (response, send_stream) = h2.send_request(request, false).unwrap(); + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let mut body = response.into_body(); + let bytes = body.data().await.unwrap().unwrap(); + assert_eq!(&bytes[..], b"Bread?"); + let _ = body.flow_control().release_capacity(bytes.len()); + + (body, send_stream) + } + + tokio::spawn(async move { + let (mut recv_stream, mut send_stream) = connect_and_recv_bread(&mut h2).await; + + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("".into(), false).unwrap(); + send_stream.send_data("Baguette!".into(), false).unwrap(); + send_stream.send_data("".into(), true).unwrap(); + + assert!(recv_stream.data().await.unwrap().unwrap().is_empty()); + }); + + let svc = service_fn(move |req: Request| { + let on_upgrade = hyper::upgrade::on(req); + + tokio::spawn(async move { + let mut upgraded = on_upgrade.await.expect("on_upgrade"); + upgraded.write_all(b"Bread?").await.unwrap(); + + let mut vec = vec![]; + upgraded.read_to_end(&mut vec).await.unwrap(); + assert_eq!(s(&vec), "Baguette!"); + + upgraded.shutdown().await.unwrap(); + }); + + future::ok::<_, hyper::Error>( + Response::builder() + .status(200) + .body(hyper::Body::empty()) + .unwrap(), + ) + }); + + let (socket, _) = listener.accept().await.unwrap(); + Http::new() + .http2_only(true) + .serve_connection(socket, svc) + .with_upgrades() + .await + .unwrap(); +} + #[tokio::test] async fn parse_errors_send_4xx_response() { let listener = tcp_bind(&"127.0.0.1:0".parse().unwrap()).unwrap();