Skip to content

Commit

Permalink
Fix HTTP/2: retry requests rejected with REFUSED_STREAM (#2081)
Browse files Browse the repository at this point in the history
  • Loading branch information
magurotuna authored Jan 15, 2024
1 parent 837a58e commit 4ab5fb0
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 20 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ mime_guess = { version = "2.0", default-features = false, optional = true }
encoding_rs = "0.8"
http-body = "0.4.0"
hyper = { version = "0.14.21", default-features = false, features = ["tcp", "http1", "http2", "client", "runtime"] }
h2 = "0.3.10"
h2 = "0.3.14"
once_cell = "1"
log = "0.4"
mime = "0.3.16"
Expand Down Expand Up @@ -155,6 +155,7 @@ libflate = "1.0"
brotli_crate = { package = "brotli", version = "3.3.0" }
doc-comment = "0.3"
tokio = { version = "1.0", default-features = false, features = ["macros", "rt-multi-thread"] }
futures-util = { version = "0.3.0", default-features = false, features = ["std", "alloc"] }

[target.'cfg(windows)'.dependencies]
winreg = "0.50.0"
Expand Down
13 changes: 10 additions & 3 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2218,9 +2218,16 @@ fn is_retryable_error(err: &(dyn std::error::Error + 'static)) -> bool {
if let Some(cause) = err.source() {
if let Some(err) = cause.downcast_ref::<h2::Error>() {
// They sent us a graceful shutdown, try with a new connection!
return err.is_go_away()
&& err.is_remote()
&& err.reason() == Some(h2::Reason::NO_ERROR);
if err.is_go_away() && err.is_remote() && err.reason() == Some(h2::Reason::NO_ERROR) {
return true;
}

// REFUSED_STREAM was sent from the server, which is safe to retry.
// https://www.rfc-editor.org/rfc/rfc9113.html#section-8.7-3.2
if err.is_reset() && err.is_remote() && err.reason() == Some(h2::Reason::REFUSED_STREAM)
{
return true;
}
}
}
false
Expand Down
68 changes: 68 additions & 0 deletions tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod support;

use futures_util::stream::StreamExt;
use support::delay_server;
use support::server;

#[cfg(feature = "json")]
Expand Down Expand Up @@ -438,3 +439,70 @@ async fn test_tls_info() {
let tls_info = resp.extensions().get::<reqwest::tls::TlsInfo>();
assert!(tls_info.is_none());
}

// NOTE: using the default "curernt_thread" runtime here would cause the test to
// fail, because the only thread would block until `panic_rx` receives a
// notification while the client needs to be driven to get the graceful shutdown
// done.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn highly_concurrent_requests_to_http2_server_with_low_max_concurrent_streams() {
let client = reqwest::Client::builder()
.http2_prior_knowledge()
.build()
.unwrap();

let server = server::http_with_config(
move |req| async move {
assert_eq!(req.version(), http::Version::HTTP_2);
http::Response::default()
},
|builder| builder.http2_only(true).http2_max_concurrent_streams(1),
);

let url = format!("http://{}", server.addr());

let futs = (0..100).map(|_| {
let client = client.clone();
let url = url.clone();
async move {
let res = client.get(&url).send().await.unwrap();
assert_eq!(res.status(), reqwest::StatusCode::OK);
}
});
futures_util::future::join_all(futs).await;
}

#[tokio::test]
async fn highly_concurrent_requests_to_slow_http2_server_with_low_max_concurrent_streams() {
let client = reqwest::Client::builder()
.http2_prior_knowledge()
.build()
.unwrap();

let server = delay_server::Server::new(
move |req| async move {
assert_eq!(req.version(), http::Version::HTTP_2);
http::Response::default()
},
|mut http| {
http.http2_only(true).http2_max_concurrent_streams(1);
http
},
std::time::Duration::from_secs(2),
)
.await;

let url = format!("http://{}", server.addr());

let futs = (0..100).map(|_| {
let client = client.clone();
let url = url.clone();
async move {
let res = client.get(&url).send().await.unwrap();
assert_eq!(res.status(), reqwest::StatusCode::OK);
}
});
futures_util::future::join_all(futs).await;

server.shutdown().await;
}
119 changes: 119 additions & 0 deletions tests/support/delay_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#![cfg(not(target_arch = "wasm32"))]
use std::convert::Infallible;
use std::future::Future;
use std::net;
use std::sync::Arc;
use std::time::Duration;

use futures_util::FutureExt;
use http::{Request, Response};
use hyper::service::service_fn;
use hyper::Body;
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::oneshot;

/// This server, unlike [`super::server::Server`], allows for delaying the
/// specified amount of time after each TCP connection is established. This is
/// useful for testing the behavior of the client when the server is slow.
///
/// For example, in case of HTTP/2, once the TCP/TLS connection is established,
/// both endpoints are supposed to send a preface and an initial `SETTINGS`
/// frame (See [RFC9113 3.4] for details). What if these frames are delayed for
/// whatever reason? This server allows for testing such scenarios.
///
/// [RFC9113 3.4]: https://www.rfc-editor.org/rfc/rfc9113.html#name-http-2-connection-preface
pub struct Server {
addr: net::SocketAddr,
shutdown_tx: Option<oneshot::Sender<()>>,
server_terminated_rx: oneshot::Receiver<()>,
}

impl Server {
pub async fn new<F1, Fut, F2>(func: F1, apply_config: F2, delay: Duration) -> Self
where
F1: Fn(Request<Body>) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Response<Body>> + Send + 'static,
F2: FnOnce(hyper::server::conn::Http) -> hyper::server::conn::Http + Send + 'static,
{
let (shutdown_tx, shutdown_rx) = oneshot::channel();
let (server_terminated_tx, server_terminated_rx) = oneshot::channel();

let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = tcp_listener.local_addr().unwrap();

tokio::spawn(async move {
let http = Arc::new(apply_config(hyper::server::conn::Http::new()));

tokio::spawn(async move {
let (connection_shutdown_tx, connection_shutdown_rx) = oneshot::channel();
let connection_shutdown_rx = connection_shutdown_rx.shared();
let mut shutdown_rx = std::pin::pin!(shutdown_rx);

let mut handles = Vec::new();
loop {
select! {
_ = shutdown_rx.as_mut() => {
connection_shutdown_tx.send(()).unwrap();
break;
}
res = tcp_listener.accept() => {
let (stream, _) = res.unwrap();


let handle = tokio::spawn({
let connection_shutdown_rx = connection_shutdown_rx.clone();
let http = http.clone();
let func = func.clone();

async move {
tokio::time::sleep(delay).await;

let mut conn = std::pin::pin!(http.serve_connection(
stream,
service_fn(move |req| {
let fut = func(req);
async move {
Ok::<_, Infallible>(fut.await)
}})
));

select! {
_ = conn.as_mut() => {}
_ = connection_shutdown_rx => {
conn.as_mut().graceful_shutdown();
conn.await.unwrap();
}
}
}
});

handles.push(handle);
}
}
}

futures_util::future::join_all(handles).await;
server_terminated_tx.send(()).unwrap();
});
});

Self {
addr,
shutdown_tx: Some(shutdown_tx),
server_terminated_rx,
}
}

pub async fn shutdown(mut self) {
if let Some(tx) = self.shutdown_tx.take() {
let _ = tx.send(());
}

self.server_terminated_rx.await.unwrap();
}

pub fn addr(&self) -> net::SocketAddr {
self.addr
}
}
1 change: 1 addition & 0 deletions tests/support/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod delay_server;
pub mod server;

// TODO: remove once done converting to new support server?
Expand Down
42 changes: 26 additions & 16 deletions tests/support/server.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
#![cfg(not(target_arch = "wasm32"))]
use std::convert::Infallible;
use std::convert::{identity, Infallible};
use std::future::Future;
use std::net;
use std::sync::mpsc as std_mpsc;
use std::thread;
use std::time::Duration;

use tokio::sync::oneshot;

pub use http::Response;
use hyper::server::conn::AddrIncoming;
use tokio::runtime;
use tokio::sync::oneshot;

pub struct Server {
addr: net::SocketAddr,
Expand Down Expand Up @@ -42,24 +41,35 @@ where
F: Fn(http::Request<hyper::Body>) -> Fut + Clone + Send + 'static,
Fut: Future<Output = http::Response<hyper::Body>> + Send + 'static,
{
//Spawn new runtime in thread to prevent reactor execution context conflict
http_with_config(func, identity)
}

pub fn http_with_config<F1, Fut, F2>(func: F1, apply_config: F2) -> Server
where
F1: Fn(http::Request<hyper::Body>) -> Fut + Clone + Send + 'static,
Fut: Future<Output = http::Response<hyper::Body>> + Send + 'static,
F2: FnOnce(hyper::server::Builder<AddrIncoming>) -> hyper::server::Builder<AddrIncoming>
+ Send
+ 'static,
{
// Spawn new runtime in thread to prevent reactor execution context conflict
thread::spawn(move || {
let rt = runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("new rt");
let srv = rt.block_on(async move {
hyper::Server::bind(&([127, 0, 0, 1], 0).into()).serve(hyper::service::make_service_fn(
move |_| {
let func = func.clone();
async move {
Ok::<_, Infallible>(hyper::service::service_fn(move |req| {
let fut = func(req);
async move { Ok::<_, Infallible>(fut.await) }
}))
}
},
))
let builder = hyper::Server::bind(&([127, 0, 0, 1], 0).into());

apply_config(builder).serve(hyper::service::make_service_fn(move |_| {
let func = func.clone();
async move {
Ok::<_, Infallible>(hyper::service::service_fn(move |req| {
let fut = func(req);
async move { Ok::<_, Infallible>(fut.await) }
}))
}
}))
});

let addr = srv.local_addr();
Expand Down

0 comments on commit 4ab5fb0

Please sign in to comment.