Skip to content

Commit

Permalink
Add HTTP Upgrade support to Response. (#1376)
Browse files Browse the repository at this point in the history
  • Loading branch information
luqmana authored Jul 28, 2022
1 parent e9ba0a9 commit 61474f4
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 42 deletions.
1 change: 1 addition & 0 deletions src/async_impl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ pub mod decoder;
pub mod multipart;
pub(crate) mod request;
mod response;
mod upgrade;
72 changes: 30 additions & 42 deletions src/async_impl/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,10 @@ use crate::response::ResponseUrl;

/// A Response to a submitted `Request`.
pub struct Response {
status: StatusCode,
headers: HeaderMap,
pub(super) res: hyper::Response<Decoder>,
// Boxed to save space (11 words to 1 word), and it's not accessed
// frequently internally.
url: Box<Url>,
body: Decoder,
version: Version,
extensions: http::Extensions,
}

impl Response {
Expand All @@ -41,46 +37,38 @@ impl Response {
accepts: Accepts,
timeout: Option<Pin<Box<Sleep>>>,
) -> Response {
let (parts, body) = res.into_parts();
let status = parts.status;
let version = parts.version;
let extensions = parts.extensions;

let mut headers = parts.headers;
let decoder = Decoder::detect(&mut headers, Body::response(body, timeout), accepts);
let (mut parts, body) = res.into_parts();
let decoder = Decoder::detect(&mut parts.headers, Body::response(body, timeout), accepts);
let res = hyper::Response::from_parts(parts, decoder);

Response {
status,
headers,
res,
url: Box::new(url),
body: decoder,
version,
extensions,
}
}

/// Get the `StatusCode` of this `Response`.
#[inline]
pub fn status(&self) -> StatusCode {
self.status
self.res.status()
}

/// Get the HTTP `Version` of this `Response`.
#[inline]
pub fn version(&self) -> Version {
self.version
self.res.version()
}

/// Get the `Headers` of this `Response`.
#[inline]
pub fn headers(&self) -> &HeaderMap {
&self.headers
self.res.headers()
}

/// Get a mutable reference to the `Headers` of this `Response`.
#[inline]
pub fn headers_mut(&mut self) -> &mut HeaderMap {
&mut self.headers
self.res.headers_mut()
}

/// Get the content-length of this response, if known.
Expand All @@ -93,7 +81,7 @@ impl Response {
pub fn content_length(&self) -> Option<u64> {
use hyper::body::HttpBody;

HttpBody::size_hint(&self.body).exact()
HttpBody::size_hint(self.res.body()).exact()
}

/// Retrieve the cookies contained in the response.
Expand All @@ -106,7 +94,7 @@ impl Response {
#[cfg(feature = "cookies")]
#[cfg_attr(docsrs, doc(cfg(feature = "cookies")))]
pub fn cookies<'a>(&'a self) -> impl Iterator<Item = cookie::Cookie<'a>> + 'a {
cookie::extract_response_cookies(&self.headers).filter_map(Result::ok)
cookie::extract_response_cookies(self.res.headers()).filter_map(Result::ok)
}

/// Get the final `Url` of this `Response`.
Expand All @@ -117,19 +105,20 @@ impl Response {

/// Get the remote address used to get this `Response`.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.extensions
self.res
.extensions()
.get::<HttpInfo>()
.map(|info| info.remote_addr())
}

/// Returns a reference to the associated extensions.
pub fn extensions(&self) -> &http::Extensions {
&self.extensions
self.res.extensions()
}

/// Returns a mutable reference to the associated extensions.
pub fn extensions_mut(&mut self) -> &mut http::Extensions {
&mut self.extensions
self.res.extensions_mut()
}

// body methods
Expand Down Expand Up @@ -183,7 +172,7 @@ impl Response {
/// ```
pub async fn text_with_charset(self, default_encoding: &str) -> crate::Result<String> {
let content_type = self
.headers
.headers()
.get(crate::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse::<Mime>().ok());
Expand Down Expand Up @@ -271,7 +260,7 @@ impl Response {
/// # }
/// ```
pub async fn bytes(self) -> crate::Result<Bytes> {
hyper::body::to_bytes(self.body).await
hyper::body::to_bytes(self.res.into_body()).await
}

/// Stream a chunk of the response body.
Expand All @@ -291,7 +280,7 @@ impl Response {
/// # }
/// ```
pub async fn chunk(&mut self) -> crate::Result<Option<Bytes>> {
if let Some(item) = self.body.next().await {
if let Some(item) = self.res.body_mut().next().await {
Ok(Some(item?))
} else {
Ok(None)
Expand Down Expand Up @@ -323,7 +312,7 @@ impl Response {
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
pub fn bytes_stream(self) -> impl futures_core::Stream<Item = crate::Result<Bytes>> {
self.body
self.res.into_body()
}

// util methods
Expand All @@ -350,8 +339,9 @@ impl Response {
/// # fn main() {}
/// ```
pub fn error_for_status(self) -> crate::Result<Self> {
if self.status.is_client_error() || self.status.is_server_error() {
Err(crate::error::status_code(*self.url, self.status))
let status = self.status();
if status.is_client_error() || status.is_server_error() {
Err(crate::error::status_code(*self.url, status))
} else {
Ok(self)
}
Expand Down Expand Up @@ -379,8 +369,9 @@ impl Response {
/// # fn main() {}
/// ```
pub fn error_for_status_ref(&self) -> crate::Result<&Self> {
if self.status.is_client_error() || self.status.is_server_error() {
Err(crate::error::status_code(*self.url.clone(), self.status))
let status = self.status();
if status.is_client_error() || status.is_server_error() {
Err(crate::error::status_code(*self.url.clone(), status))
} else {
Ok(self)
}
Expand All @@ -395,7 +386,7 @@ impl Response {
// This method is just used by the blocking API.
#[cfg(feature = "blocking")]
pub(crate) fn body_mut(&mut self) -> &mut Decoder {
&mut self.body
self.res.body_mut()
}
}

Expand All @@ -413,27 +404,24 @@ impl<T: Into<Body>> From<http::Response<T>> for Response {
fn from(r: http::Response<T>) -> Response {
let (mut parts, body) = r.into_parts();
let body = body.into();
let body = Decoder::detect(&mut parts.headers, body, Accepts::none());
let decoder = Decoder::detect(&mut parts.headers, body, Accepts::none());
let url = parts
.extensions
.remove::<ResponseUrl>()
.unwrap_or_else(|| ResponseUrl(Url::parse("http://no.url.provided.local").unwrap()));
let url = url.0;
let res = hyper::Response::from_parts(parts, decoder);
Response {
status: parts.status,
headers: parts.headers,
res,
url: Box::new(url),
body,
version: parts.version,
extensions: parts.extensions,
}
}
}

/// A `Response` can be piped as the `Body` of another request.
impl From<Response> for Body {
fn from(r: Response) -> Body {
Body::stream(r.body)
Body::stream(r.res.into_body())
}
}

Expand Down
73 changes: 73 additions & 0 deletions src/async_impl/upgrade.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::pin::Pin;
use std::task::{self, Poll};
use std::{fmt, io};

use futures_util::TryFutureExt;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

/// An upgraded HTTP connection.
pub struct Upgraded {
inner: hyper::upgrade::Upgraded,
}

impl AsyncRead for Upgraded {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}

impl AsyncWrite for Upgraded {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

impl fmt::Debug for Upgraded {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Upgraded").finish()
}
}

impl From<hyper::upgrade::Upgraded> for Upgraded {
fn from(inner: hyper::upgrade::Upgraded) -> Self {
Upgraded { inner }
}
}

impl super::response::Response {
/// Consumes the response and returns a future for a possible HTTP upgrade.
pub async fn upgrade(self) -> crate::Result<Upgraded> {
hyper::upgrade::on(self.res)
.map_ok(Upgraded::from)
.map_err(crate::error::upgrade)
.await
}
}
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ impl fmt::Display for Error {
Kind::Body => f.write_str("request or response body error")?,
Kind::Decode => f.write_str("error decoding response body")?,
Kind::Redirect => f.write_str("error following redirect")?,
Kind::Upgrade => f.write_str("error upgrading connection")?,
Kind::Status(ref code) => {
let prefix = if code.is_client_error() {
"HTTP status client error"
Expand Down Expand Up @@ -236,6 +237,7 @@ pub(crate) enum Kind {
Status(StatusCode),
Body,
Decode,
Upgrade,
}

// constructors
Expand Down Expand Up @@ -274,6 +276,10 @@ if_wasm! {
}
}

pub(crate) fn upgrade<E: Into<BoxError>>(e: E) -> Error {
Error::new(Kind::Upgrade, Some(e))
}

// io::Error helpers

#[allow(unused)]
Expand Down
51 changes: 51 additions & 0 deletions tests/upgrade.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#![cfg(not(target_arch = "wasm32"))]
mod support;
use support::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::test]
async fn http_upgrade() {
let server = server::http(move |req| {
assert_eq!(req.method(), "GET");
assert_eq!(req.headers()["connection"], "upgrade");
assert_eq!(req.headers()["upgrade"], "foobar");

tokio::spawn(async move {
let mut upgraded = hyper::upgrade::on(req).await.unwrap();

let mut buf = vec![0; 7];
upgraded.read_exact(&mut buf).await.unwrap();
assert_eq!(buf, b"foo=bar");

upgraded.write_all(b"bar=foo").await.unwrap();
});

async {
http::Response::builder()
.status(http::StatusCode::SWITCHING_PROTOCOLS)
.header(http::header::CONNECTION, "upgrade")
.header(http::header::UPGRADE, "foobar")
.body(hyper::Body::empty())
.unwrap()
}
});

let res = reqwest::Client::builder()
.build()
.unwrap()
.get(format!("http://{}", server.addr()))
.header(http::header::CONNECTION, "upgrade")
.header(http::header::UPGRADE, "foobar")
.send()
.await
.unwrap();

assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS);
let mut upgraded = res.upgrade().await.unwrap();

upgraded.write_all(b"foo=bar").await.unwrap();

let mut buf = vec![];
upgraded.read_to_end(&mut buf).await.unwrap();
assert_eq!(buf, b"bar=foo");
}

0 comments on commit 61474f4

Please sign in to comment.