diff --git a/tests/integration_tests/tests/http2_max_header_list_size.rs b/tests/integration_tests/tests/http2_max_header_list_size.rs new file mode 100644 index 000000000..2909bf048 --- /dev/null +++ b/tests/integration_tests/tests/http2_max_header_list_size.rs @@ -0,0 +1,70 @@ +use std::time::Duration; + +use integration_tests::pb::{test_client, test_server, Input, Output}; +use tokio::sync::oneshot; +use tonic::{ + transport::{Endpoint, Server}, + Request, Response, Status, +}; + +/// This test checks that the max header list size is respected, and that +/// it allows for error messages up to that size. +#[tokio::test] +async fn test_http_max_header_list_size_and_long_errors() { + struct Svc; + + // The default value is 16k. + const N: usize = 20_000; + + fn long_message() -> String { + "a".repeat(N) + } + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, _: Request) -> Result, Status> { + Err(Status::internal(long_message())) + } + } + + let svc = test_server::TestServer::new(Svc); + + let (tx, rx) = oneshot::channel::<()>(); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = format!("http://{}", listener.local_addr().unwrap()); + + let jh = tokio::spawn(async move { + let (nodelay, keepalive) = (true, Some(Duration::from_secs(1))); + let listener = + tonic::transport::server::TcpIncoming::from_listener(listener, nodelay, keepalive) + .unwrap(); + Server::builder() + .http2_max_pending_accept_reset_streams(Some(0)) + .add_service(svc) + .serve_with_incoming_shutdown(listener, async { drop(rx.await) }) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_shared(addr) + .unwrap() + // Increase the max header list size to 2x the default value. If this is + // not set, this test will hang. See + // . + .http2_max_header_list_size(u32::try_from(N * 2).unwrap()) + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + let err = client.unary_call(Request::new(Input {})).await.unwrap_err(); + + assert_eq!(err.message(), long_message()); + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index e8e8b957f..37ace4f12 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -33,6 +33,7 @@ pub struct Endpoint { pub(crate) http2_keep_alive_interval: Option, pub(crate) http2_keep_alive_timeout: Option, pub(crate) http2_keep_alive_while_idle: Option, + pub(crate) http2_max_header_list_size: Option, pub(crate) connect_timeout: Option, pub(crate) http2_adaptive_window: Option, pub(crate) executor: SharedExec, @@ -290,6 +291,14 @@ impl Endpoint { } } + /// Sets the max header list size. Uses `hyper`'s default otherwise. + pub fn http2_max_header_list_size(self, size: u32) -> Self { + Endpoint { + http2_max_header_list_size: Some(size), + ..self + } + } + /// Sets the executor used to spawn async tasks. /// /// Uses `tokio::spawn` by default. @@ -420,6 +429,7 @@ impl From for Endpoint { http2_keep_alive_interval: None, http2_keep_alive_timeout: None, http2_keep_alive_while_idle: None, + http2_max_header_list_size: None, connect_timeout: None, http2_adaptive_window: None, executor: SharedExec::tokio(), diff --git a/tonic/src/transport/channel/service/connection.rs b/tonic/src/transport/channel/service/connection.rs index d85586f60..ea23bb32f 100644 --- a/tonic/src/transport/channel/service/connection.rs +++ b/tonic/src/transport/channel/service/connection.rs @@ -51,6 +51,10 @@ impl Connection { settings.adaptive_window(val); } + if let Some(val) = endpoint.http2_max_header_list_size { + settings.max_header_list_size(val); + } + let stack = ServiceBuilder::new() .layer_fn(|s| { let origin = endpoint.origin.as_ref().unwrap_or(&endpoint.uri).clone();