-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HTTP Upgrade support to Response. (#1376)
- Loading branch information
Showing
5 changed files
with
161 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ pub mod decoder; | |
pub mod multipart; | ||
pub(crate) mod request; | ||
mod response; | ||
mod upgrade; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
use std::pin::Pin; | ||
use std::task::{self, Poll}; | ||
use std::{fmt, io}; | ||
|
||
use futures_util::TryFutureExt; | ||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; | ||
|
||
/// An upgraded HTTP connection. | ||
pub struct Upgraded { | ||
inner: hyper::upgrade::Upgraded, | ||
} | ||
|
||
impl AsyncRead for Upgraded { | ||
fn poll_read( | ||
mut self: Pin<&mut Self>, | ||
cx: &mut task::Context<'_>, | ||
buf: &mut ReadBuf<'_>, | ||
) -> Poll<io::Result<()>> { | ||
Pin::new(&mut self.inner).poll_read(cx, buf) | ||
} | ||
} | ||
|
||
impl AsyncWrite for Upgraded { | ||
fn poll_write( | ||
mut self: Pin<&mut Self>, | ||
cx: &mut task::Context<'_>, | ||
buf: &[u8], | ||
) -> Poll<io::Result<usize>> { | ||
Pin::new(&mut self.inner).poll_write(cx, buf) | ||
} | ||
|
||
fn poll_write_vectored( | ||
mut self: Pin<&mut Self>, | ||
cx: &mut task::Context<'_>, | ||
bufs: &[io::IoSlice<'_>], | ||
) -> Poll<io::Result<usize>> { | ||
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) | ||
} | ||
|
||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { | ||
Pin::new(&mut self.inner).poll_flush(cx) | ||
} | ||
|
||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> { | ||
Pin::new(&mut self.inner).poll_shutdown(cx) | ||
} | ||
|
||
fn is_write_vectored(&self) -> bool { | ||
self.inner.is_write_vectored() | ||
} | ||
} | ||
|
||
impl fmt::Debug for Upgraded { | ||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
f.debug_struct("Upgraded").finish() | ||
} | ||
} | ||
|
||
impl From<hyper::upgrade::Upgraded> for Upgraded { | ||
fn from(inner: hyper::upgrade::Upgraded) -> Self { | ||
Upgraded { inner } | ||
} | ||
} | ||
|
||
impl super::response::Response { | ||
/// Consumes the response and returns a future for a possible HTTP upgrade. | ||
pub async fn upgrade(self) -> crate::Result<Upgraded> { | ||
hyper::upgrade::on(self.res) | ||
.map_ok(Upgraded::from) | ||
.map_err(crate::error::upgrade) | ||
.await | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#![cfg(not(target_arch = "wasm32"))] | ||
mod support; | ||
use support::*; | ||
use tokio::io::{AsyncReadExt, AsyncWriteExt}; | ||
|
||
#[tokio::test] | ||
async fn http_upgrade() { | ||
let server = server::http(move |req| { | ||
assert_eq!(req.method(), "GET"); | ||
assert_eq!(req.headers()["connection"], "upgrade"); | ||
assert_eq!(req.headers()["upgrade"], "foobar"); | ||
|
||
tokio::spawn(async move { | ||
let mut upgraded = hyper::upgrade::on(req).await.unwrap(); | ||
|
||
let mut buf = vec![0; 7]; | ||
upgraded.read_exact(&mut buf).await.unwrap(); | ||
assert_eq!(buf, b"foo=bar"); | ||
|
||
upgraded.write_all(b"bar=foo").await.unwrap(); | ||
}); | ||
|
||
async { | ||
http::Response::builder() | ||
.status(http::StatusCode::SWITCHING_PROTOCOLS) | ||
.header(http::header::CONNECTION, "upgrade") | ||
.header(http::header::UPGRADE, "foobar") | ||
.body(hyper::Body::empty()) | ||
.unwrap() | ||
} | ||
}); | ||
|
||
let res = reqwest::Client::builder() | ||
.build() | ||
.unwrap() | ||
.get(format!("http://{}", server.addr())) | ||
.header(http::header::CONNECTION, "upgrade") | ||
.header(http::header::UPGRADE, "foobar") | ||
.send() | ||
.await | ||
.unwrap(); | ||
|
||
assert_eq!(res.status(), http::StatusCode::SWITCHING_PROTOCOLS); | ||
let mut upgraded = res.upgrade().await.unwrap(); | ||
|
||
upgraded.write_all(b"foo=bar").await.unwrap(); | ||
|
||
let mut buf = vec![]; | ||
upgraded.read_to_end(&mut buf).await.unwrap(); | ||
assert_eq!(buf, b"bar=foo"); | ||
} |