From 9c1f2f9402d97cfa4c9d9065b5af2fb99d2ef521 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 12 Jun 2024 12:49:23 -0700 Subject: [PATCH] Upgrade to Hyper 1.0 & Axum 0.7 (#1670) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Upgrade to hyper 1 and http 1 Upgrades only in Cargo.toml Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Convert from hyper::Body to http_body::BoxedBody When appropriate, we replace `hyper::Body` with `http_body::BoxedBody`, a good general purpose replacement for `hyper::Body`. Hyper does provide `hyper::body::Incoming`, but we cannot construct that, so anywhere we might need a body that we can construct (even most Service trait impls) we must use something like `http_body::BoxedBody`. When a service accepts `BoxedBody` and not `Incoming`, this indicates that the service is designed to run in places where it is not adjacent to hyper, for example, after routing (which is managed by Axum) Additionally, http >= 1 requires that extension types are `Clone`, so this bound has been added where appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Convert tonic::codec::decode to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev * Convert tonic::transport::channel to use http >= 1 body types tonic::transport::channel previously used `hyper::Body` as the response body type. This type no longer exists in hyper >= 1, and so has been converted to a `BoxBody` provided by `http_body_util` designed for interoperability between http crates. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * [tests] Convert tonic::codec::prost::tests to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. This also handles the return types which should now be wrapped in `Frame` when appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Convert tonic::codec::encode to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * [tests] Convert tonic::service::interceptor::tests to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. This also handles the return types which should now be wrapped in `Frame` when appropriate. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Convert tonic::transport to use http >= 1 body types Here, we must update some body types which are no longer valid. (A) BoxBody no longer has an `empty` method, instead we provide a helper in `tonic::body` for creating an empty boxed body via `http_body_util`. As well, `hyper::Body` is no longer a type, and instead, `hyper::Incoming` is used when directly recieving a Request from hyper, and `BoxBody` is used when the request may have passed through an axum router. In tonic, we prefer `BoxBody` as it allows for services to be used downstream from other components which enforce a specific body type (e.g. Axum), at the cost of making Body streaming opaque. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Convert tonic::transport::server::recover_error to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. Co-authored-by: Ivan Krivosheev Co-authored-by: Ludea * Convert h2c examples to use http >= 1 body types In h2c, when a service is receiving from hyper, it has to accept a `hyper::body::Incoming` in hyper >= 1. Additionally, response bodies must be built from `http_body_util` combinators and become BoxBody objects. * [tests] Convert MergeTrailers body wrapper in interop server The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. * [tests] Convert compression tests to use hyper 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. * [tests] Convert complex_tower_middleware Body for hyper 1 The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. * [tests] Convert integration_tests::origin to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. * Convert tonic-web to use http >= 1 body types The Body trait has changed (removed `poll_data` and `poll_trailers`, they are now combined in `poll_frame`) and so the codec must be re-written to merge those two methods. * Adapt for hyper-specific IO traits hyper >= 1 provides its own I/O traits (Read & Write) instead of relying on the equivalent traits from `tokio`. Then, `hyper-util` provides adaptor structs to wrap `tokio` I/O objects and implement the hyper equivalents. Therefore, we update the appropriate bounds to use the hyper traits, and update the I/O objects so that they are wrapped in the tokio to hyper adaptor. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Upgrade axum to 0.7 Axum must be >= 0.7 to support hyper >= 1 Doing this also involves changing the Body type used. Since hyper >= 1 does not provide a generic body type, Axum and tonic both use `BoxBody` to provide a pointer to a Body. This changes the trait bounds required for methods which accept additional Serivces to be run alongside the primary GRPC service, since those will be routed with Axum, and therefore must accept a BoxBody. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Convert service connector for hyper-1.0 Hyper >= 1 no longer includes automatic http2/http1 combined connections, and so we must swtich to the `http2::Builder` type (this is okay, we set http2_only(true) anyhow). As well, hyper >= 1 is generic over executors and does not directly depend on tokio. Since http2 connections can be multiplexed, they require some additional background task to handle sending and receiving requests. Additionally, these background tasks do not natively implement `tower::Service` since hyper >= 1 does not depend on `tower`. Therefore, we re-implement the `SendRequest` task as a tower::Service, so that it can be used within `Connection`, which expects to operate on a tower::Service to serve connections. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Convert hyper::Client to hyper_util::legacy::Client `hyper::Client` has been moved to `hyper_util::legacy::Client` in version 1. Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang * Identify and propogate connect errors hyper::Error no longer provides information about Connect errors, especially since hyper_util now contains the connection implementation, it does not provide a separate error type. Instead, we create an internal Error type which is used in our own connectors, and then checked when figuring out what the gRPC status should be. * Remove hyper::server::conn::AddrStream hyper >= 1 has deprecated all of `hyper::server`, including `AddrStream` Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang Replace hyper::server::Accept hyper::server is deprectaed. Instead, we implement our own TCP-incoming based on the now removed hyper::server::Accept. In order to set `TCP_KEEPALIVE` we require the socket2 crate, since this option is not exposed in the standard library’s API. The implementaiton is inspired by that of hyper v0.14 * [examples] In h2c, replace hyper::Server with an accept loop hyper::Server is deprecated, with no current common replacement. Instead of implementing (or using tonic’s new) full server in here, we write a simple accept loop, which is sufficient to demonstrate the functionality of h2c. * Upgrade tls dependencies hyper-rustls requires version 0.27.0 to support hyper >= 1, bringing a few other tls bumps along. Importantly, we add the “ring” and “tls12” features to use ring as the crypto backend, consistent with previous versions of tonic. A future version of tonic might support selecting backends via features. Co-authored-by: Ivan Krivosheev * Combine trailers when streaming decode body We aren't sure if multiple trailers should even be legal, but if we get multiple trailers in an HTTP body stream, we'll combine them all, to preserve their data. Alternatively we'd have to pick the first or last trailers, and that might lose information. * Tweak imports in transport example Example used `empty_body()`, which is now fully qualified as `tonic::body::empty_body()` to make clear that this is a tonic helper method for creating an empty BoxBody. * Remove commented out code from examples/h2c * tonic-web: avoid copy to vector to base64 encode * tonic-web: Merge subsequent trailer frames Ideally, a body should only return a single trailer frame. If multiple trailers are returned, merge them together. * Comment in tonic::status::find_status_in_source_chain Comment mentions why we choose “Unavailable” for connection errors * Make TowerToHyperService crate-private This also requires vendoring it in the rustls example, which doesn’t use a server type. Making the type crate-private means we can delete some unused methods. * Fixup imports in tonic::transport --------- Co-authored-by: Ivan Krivosheev Co-authored-by: Allan Zhang Co-authored-by: Ludea --- examples/Cargo.toml | 29 +- examples/src/grpc-web/client.rs | 3 +- examples/src/h2c/client.rs | 29 +- examples/src/h2c/server.rs | 54 ++- examples/src/interceptor/server.rs | 1 + examples/src/mock/mock.rs | 3 +- examples/src/tls_rustls/client.rs | 10 +- examples/src/tls_rustls/server.rs | 109 +++++- examples/src/tower/client.rs | 3 +- examples/src/tower/server.rs | 10 +- examples/src/uds/client.rs | 6 +- interop/Cargo.toml | 6 +- interop/src/server.rs | 40 +-- tests/compression/Cargo.toml | 10 +- tests/compression/src/util.rs | 83 +++-- tests/integration_tests/Cargo.toml | 9 +- .../tests/complex_tower_middleware.rs | 11 +- tests/integration_tests/tests/connect_info.rs | 8 +- tests/integration_tests/tests/extensions.rs | 9 +- .../tests/max_message_size.rs | 5 +- tests/integration_tests/tests/origin.rs | 6 +- tests/integration_tests/tests/status.rs | 3 +- tonic-web/Cargo.toml | 8 +- tonic-web/src/call.rs | 139 ++++---- tonic-web/src/layer.rs | 2 +- tonic-web/src/lib.rs | 10 +- tonic-web/src/service.rs | 39 +- tonic-web/tests/integration/Cargo.toml | 5 +- tonic-web/tests/integration/tests/grpc_web.rs | 22 +- tonic/Cargo.toml | 25 +- tonic/benches/decode.rs | 21 +- tonic/src/body.rs | 6 +- tonic/src/codec/decode.rs | 79 +++-- tonic/src/codec/encode.rs | 46 +-- tonic/src/codec/prost.rs | 44 +-- tonic/src/extensions.rs | 2 +- tonic/src/request.rs | 2 +- tonic/src/service/interceptor.rs | 24 +- tonic/src/status.rs | 18 +- tonic/src/transport/channel/endpoint.rs | 16 +- tonic/src/transport/channel/mod.rs | 20 +- tonic/src/transport/mod.rs | 20 +- tonic/src/transport/server/conn.rs | 12 - tonic/src/transport/server/incoming.rs | 70 ++-- tonic/src/transport/server/mod.rs | 334 ++++++++++++++---- tonic/src/transport/server/recover_error.rs | 17 +- tonic/src/transport/service/connection.rs | 127 +++++-- tonic/src/transport/service/connector.rs | 68 ++-- tonic/src/transport/service/discover.rs | 3 +- tonic/src/transport/service/executor.rs | 9 +- tonic/src/transport/service/io.rs | 13 +- tonic/src/transport/service/mod.rs | 1 + tonic/src/transport/service/router.rs | 53 ++- tonic/src/transport/service/tls.rs | 3 +- 54 files changed, 1104 insertions(+), 601 deletions(-) diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 57d05d3e3..deab25fd4 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -271,21 +271,21 @@ routeguide = ["dep:async-stream", "tokio-stream", "dep:rand", "dep:serde", "dep: reflection = ["dep:tonic-reflection"] autoreload = ["tokio-stream/net", "dep:listenfd"] health = ["dep:tonic-health"] -grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:tracing-subscriber", "dep:tower"] +grpc-web = ["dep:tonic-web", "dep:bytes", "dep:http", "dep:hyper", "dep:hyper-util", "dep:tracing-subscriber", "dep:tower"] tracing = ["dep:tracing", "dep:tracing-subscriber"] -uds = ["tokio-stream/net", "dep:tower", "dep:hyper"] +uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"] streaming = ["tokio-stream", "dep:h2"] -mock = ["tokio-stream", "dep:tower"] -tower = ["dep:hyper", "dep:tower", "dep:http"] +mock = ["tokio-stream", "dep:tower", "dep:hyper-util"] +tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"] json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"] compression = ["tonic/gzip"] tls = ["tonic/tls"] -tls-rustls = ["dep:hyper", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls"] +tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:pin-project", "dep:http-body-util"] dynamic-load-balance = ["dep:tower"] timeout = ["tokio/time", "dep:tower"] tls-client-auth = ["tonic/tls"] types = ["dep:tonic-types"] -h2c = ["dep:hyper", "dep:tower", "dep:http"] +h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"] cancellation = ["dep:tokio-util"] full = ["gcp", "routeguide", "reflection", "autoreload", "health", "grpc-web", "tracing", "uds", "streaming", "mock", "tower", "json-codec", "compression", "tls", "tls-rustls", "dynamic-load-balance", "timeout", "tls-client-auth", "types", "cancellation", "h2c"] @@ -311,16 +311,19 @@ serde_json = { version = "1.0", optional = true } tracing = { version = "0.1.16", optional = true } tracing-subscriber = { version = "0.3", features = ["tracing-log", "fmt"], optional = true } prost-types = { version = "0.12", optional = true } -http = { version = "0.2", optional = true } -http-body = { version = "0.4.2", optional = true } -hyper = { version = "0.14", optional = true } +http = { version = "1", optional = true } +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } +hyper = { version = "1", optional = true } +hyper-util = { version = ">=0.1.4, <0.2", optional = true } listenfd = { version = "1.0", optional = true } bytes = { version = "1", optional = true } h2 = { version = "0.3", optional = true } -tokio-rustls = { version = "0.24.0", optional = true } -hyper-rustls = { version = "0.24.0", features = ["http2"], optional = true } -rustls-pemfile = { version = "1", optional = true } -tower-http = { version = "0.4", optional = true } +tokio-rustls = { version = "0.26", optional = true, features = ["ring", "tls12"], default-features = false } +hyper-rustls = { version = "0.27.0", features = ["http2", "ring", "tls12"], optional = true, default-features = false } +rustls-pemfile = { version = "2.0.0", optional = true } +tower-http = { version = "0.5", optional = true } +pin-project = { version = "1.0.11", optional = true } [build-dependencies] tonic-build = { path = "../tonic-build", features = ["prost"] } diff --git a/examples/src/grpc-web/client.rs b/examples/src/grpc-web/client.rs index a16ac674a..fa64dd506 100644 --- a/examples/src/grpc-web/client.rs +++ b/examples/src/grpc-web/client.rs @@ -1,4 +1,5 @@ use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioExecutor; use tonic_web::GrpcWebClientLayer; pub mod hello_world { @@ -8,7 +9,7 @@ pub mod hello_world { #[tokio::main] async fn main() -> Result<(), Box> { // Must use hyper directly... - let client = hyper::Client::builder().build_http(); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build_http(); let svc = tower::ServiceBuilder::new() .layer(GrpcWebClientLayer::new()) diff --git a/examples/src/h2c/client.rs b/examples/src/h2c/client.rs index 31076b1ac..b162fcc08 100644 --- a/examples/src/h2c/client.rs +++ b/examples/src/h2c/client.rs @@ -1,7 +1,8 @@ use hello_world::greeter_client::GreeterClient; use hello_world::HelloRequest; use http::Uri; -use hyper::Client; +use hyper_util::client::legacy::Client; +use hyper_util::rt::TokioExecutor; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -11,7 +12,7 @@ pub mod hello_world { async fn main() -> Result<(), Box> { let origin = Uri::from_static("http://[::1]:50051"); let h2c_client = h2c::H2cChannel { - client: Client::new(), + client: Client::builder(TokioExecutor::new()).build_http(), }; let mut client = GreeterClient::with_origin(h2c_client, origin); @@ -33,16 +34,20 @@ mod h2c { task::{Context, Poll}, }; - use hyper::{client::HttpConnector, Client}; - use tonic::body::BoxBody; + use hyper::body::Incoming; + use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, + }; + use tonic::body::{empty_body, BoxBody}; use tower::Service; pub struct H2cChannel { - pub client: Client, + pub client: Client, } impl Service> for H2cChannel { - type Response = http::Response; + type Response = http::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -60,7 +65,7 @@ mod h2c { let h2c_req = hyper::Request::builder() .uri(origin) .header(http::header::UPGRADE, "h2c") - .body(hyper::Body::empty()) + .body(empty_body()) .unwrap(); let res = client.request(h2c_req).await.unwrap(); @@ -72,11 +77,11 @@ mod h2c { let upgraded_io = hyper::upgrade::on(res).await.unwrap(); // In an ideal world you would somehow cache this connection - let (mut h2_client, conn) = hyper::client::conn::Builder::new() - .http2_only(true) - .handshake(upgraded_io) - .await - .unwrap(); + let (mut h2_client, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(upgraded_io) + .await + .unwrap(); tokio::spawn(conn); h2_client.send_request(request).await diff --git a/examples/src/h2c/server.rs b/examples/src/h2c/server.rs index 92d08a417..cf981f957 100644 --- a/examples/src/h2c/server.rs +++ b/examples/src/h2c/server.rs @@ -1,8 +1,13 @@ +use std::net::SocketAddr; + +use hyper_util::rt::{TokioExecutor, TokioIo}; +use hyper_util::server::conn::auto::Builder; +use hyper_util::service::TowerToHyperService; +use tokio::net::TcpListener; use tonic::{transport::Server, Request, Response, Status}; use hello_world::greeter_server::{Greeter, GreeterServer}; use hello_world::{HelloReply, HelloRequest}; -use tower::make::Shared; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -28,28 +33,45 @@ impl Greeter for MyGreeter { #[tokio::main] async fn main() -> Result<(), Box> { - let addr = "[::1]:50051".parse().unwrap(); + let addr: SocketAddr = "[::1]:50051".parse().unwrap(); let greeter = MyGreeter::default(); println!("GreeterServer listening on {}", addr); + let incoming = TcpListener::bind(addr).await?; let svc = Server::builder() .add_service(GreeterServer::new(greeter)) .into_router(); let h2c = h2c::H2c { s: svc }; - let server = hyper::Server::bind(&addr).serve(Shared::new(h2c)); - server.await.unwrap(); - - Ok(()) + loop { + match incoming.accept().await { + Ok((io, _)) => { + let router = h2c.clone(); + tokio::spawn(async move { + let builder = Builder::new(TokioExecutor::new()); + let conn = builder.serve_connection_with_upgrades( + TokioIo::new(io), + TowerToHyperService::new(router), + ); + let _ = conn.await; + }); + } + Err(e) => { + eprintln!("Error accepting connection: {}", e); + } + } + } } mod h2c { use std::pin::Pin; use http::{Request, Response}; - use hyper::Body; + use hyper::body::Incoming; + use hyper_util::{rt::TokioExecutor, service::TowerToHyperService}; + use tonic::{body::empty_body, transport::AxumBoxBody}; use tower::Service; #[derive(Clone)] @@ -59,17 +81,14 @@ mod h2c { type BoxError = Box; - impl Service> for H2c + impl Service> for H2c where - S: Service, Response = Response> - + Clone - + Send - + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Sync + Send + 'static, S::Response: Send + 'static, { - type Response = hyper::Response; + type Response = hyper::Response; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -81,20 +100,19 @@ mod h2c { std::task::Poll::Ready(Ok(())) } - fn call(&mut self, mut req: hyper::Request) -> Self::Future { + fn call(&mut self, mut req: hyper::Request) -> Self::Future { let svc = self.s.clone(); Box::pin(async move { tokio::spawn(async move { let upgraded_io = hyper::upgrade::on(&mut req).await.unwrap(); - hyper::server::conn::Http::new() - .http2_only(true) - .serve_connection(upgraded_io, svc) + hyper::server::conn::http2::Builder::new(TokioExecutor::new()) + .serve_connection(upgraded_io, TowerToHyperService::new(svc)) .await .unwrap(); }); - let mut res = hyper::Response::new(hyper::Body::empty()); + let mut res = hyper::Response::new(empty_body()); *res.status_mut() = http::StatusCode::SWITCHING_PROTOCOLS; res.headers_mut().insert( hyper::header::UPGRADE, diff --git a/examples/src/interceptor/server.rs b/examples/src/interceptor/server.rs index 263348a6d..fd0cf462f 100644 --- a/examples/src/interceptor/server.rs +++ b/examples/src/interceptor/server.rs @@ -57,6 +57,7 @@ fn intercept(mut req: Request<()>) -> Result, Status> { Ok(req) } +#[derive(Clone)] struct MyExtension { some_piece_of_data: String, } diff --git a/examples/src/mock/mock.rs b/examples/src/mock/mock.rs index 0d3754921..6c26a6735 100644 --- a/examples/src/mock/mock.rs +++ b/examples/src/mock/mock.rs @@ -1,3 +1,4 @@ +use hyper_util::rt::TokioIo; use tonic::{ transport::{Endpoint, Server, Uri}, Request, Response, Status, @@ -36,7 +37,7 @@ async fn main() -> Result<(), Box> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/examples/src/tls_rustls/client.rs b/examples/src/tls_rustls/client.rs index 23d6e8130..4c39a2c46 100644 --- a/examples/src/tls_rustls/client.rs +++ b/examples/src/tls_rustls/client.rs @@ -5,7 +5,8 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::{client::HttpConnector, Uri}; +use hyper::Uri; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use pb::{echo_client::EchoClient, EchoRequest}; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; @@ -17,11 +18,10 @@ async fn main() -> Result<(), Box> { let mut roots = RootCertStore::empty(); let mut buf = std::io::BufReader::new(&fd); - let certs = rustls_pemfile::certs(&mut buf)?; - roots.add_parsable_certificates(&certs); + let certs = rustls_pemfile::certs(&mut buf).collect::, _>>()?; + roots.add_parsable_certificates(certs.into_iter()); let tls = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(roots) .with_no_client_auth(); @@ -47,7 +47,7 @@ async fn main() -> Result<(), Box> { .map_request(|_| Uri::from_static("https://[::1]:50051")) .service(http); - let client = hyper::Client::builder().build(connector); + let client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(connector); // Using `with_origin` will let the codegenerated client set the `scheme` and // `authority` from the porvided `Uri`. diff --git a/examples/src/tls_rustls/server.rs b/examples/src/tls_rustls/server.rs index 82f009344..0fb31f8fa 100644 --- a/examples/src/tls_rustls/server.rs +++ b/examples/src/tls_rustls/server.rs @@ -2,45 +2,52 @@ pub mod pb { tonic::include_proto!("/grpc.examples.unaryecho"); } -use hyper::server::conn::Http; +use http_body_util::BodyExt; +use hyper::server::conn::http2::Builder; +use hyper_util::rt::{TokioExecutor, TokioIo}; use pb::{EchoRequest, EchoResponse}; use std::sync::Arc; use tokio::net::TcpListener; use tokio_rustls::{ - rustls::{Certificate, PrivateKey, ServerConfig}, + rustls::{ + pki_types::{CertificateDer, PrivatePkcs8KeyDer}, + ServerConfig, + }, TlsAcceptor, }; -use tonic::{transport::Server, Request, Response, Status}; +use tonic::{body::BoxBody, transport::Server, Request, Response, Status}; +use tower::{BoxError, ServiceExt}; use tower_http::ServiceBuilderExt; #[tokio::main] async fn main() -> Result<(), Box> { let data_dir = std::path::PathBuf::from_iter([std::env!("CARGO_MANIFEST_DIR"), "data"]); - let certs = { + let certs: Vec> = { let fd = std::fs::File::open(data_dir.join("tls/server.pem"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::certs(&mut buf)? + rustls_pemfile::certs(&mut buf) .into_iter() - .map(Certificate) - .collect() + .map(|res| res.map(|cert| cert.to_owned())) + .collect::, _>>()? }; - let key = { + let key: PrivatePkcs8KeyDer<'static> = { let fd = std::fs::File::open(data_dir.join("tls/server.key"))?; let mut buf = std::io::BufReader::new(&fd); - rustls_pemfile::pkcs8_private_keys(&mut buf)? + let key = rustls_pemfile::pkcs8_private_keys(&mut buf) .into_iter() - .map(PrivateKey) .next() - .unwrap() + .unwrap()? + .clone_key(); + + key // let key = std::fs::read(data_dir.join("tls/server.key"))?; // PrivateKey(key) }; let mut tls = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, key)?; + .with_single_cert(certs, key.into())?; tls.alpn_protocols = vec![b"h2".to_vec()]; let server = EchoServer::default(); @@ -49,8 +56,7 @@ async fn main() -> Result<(), Box> { .add_service(pb::echo_server::EchoServer::new(server)) .into_service(); - let mut http = Http::new(); - http.http2_only(true); + let http = Builder::new(TokioExecutor::new()); let listener = TcpListener::bind("[::1]:50051").await?; let tls_acceptor = TlsAcceptor::from(Arc::new(tls)); @@ -86,7 +92,9 @@ async fn main() -> Result<(), Box> { .add_extension(Arc::new(ConnInfo { addr, certificates })) .service(svc); - http.serve_connection(conn, svc).await.unwrap(); + http.serve_connection(TokioIo::new(conn), TowerToHyperService::new(svc)) + .await + .unwrap(); }); } } @@ -94,7 +102,7 @@ async fn main() -> Result<(), Box> { #[derive(Debug)] struct ConnInfo { addr: std::net::SocketAddr, - certificates: Vec, + certificates: Vec>, } type EchoResult = Result, Status>; @@ -115,3 +123,70 @@ impl pb::echo_server::Echo for EchoServer { Ok(Response::new(EchoResponse { message })) } } + +/// An adaptor which converts a [`tower::Service`] to a [`hyper::service::Service`]. +/// +/// The [`hyper::service::Service`] trait is used by hyper to handle incoming requests, +/// and does not support the `poll_ready` method that is used by tower services. +/// +/// This is provided here because the equivalent adaptor in hyper-util does not support +/// tonic::body::BoxBody bodies. +#[derive(Debug, Clone)] +struct TowerToHyperService { + service: S, +} + +impl TowerToHyperService { + /// Create a new `TowerToHyperService` from a tower service. + fn new(service: S) -> Self { + Self { service } + } +} + +impl hyper::service::Service> for TowerToHyperService +where + S: tower::Service> + Clone, + S::Error: Into + 'static, +{ + type Response = S::Response; + type Error = BoxError; + type Future = TowerToHyperServiceFuture>; + + fn call(&self, req: hyper::Request) -> Self::Future { + let req = req.map(|incoming| { + incoming + .map_err(|err| Status::from_error(err.into())) + .boxed_unsync() + }); + TowerToHyperServiceFuture { + future: self.service.clone().oneshot(req), + } + } +} + +/// Future returned by [`TowerToHyperService`]. +#[derive(Debug)] +#[pin_project::pin_project] +struct TowerToHyperServiceFuture +where + S: tower::Service, +{ + #[pin] + future: tower::util::Oneshot, +} + +impl std::future::Future for TowerToHyperServiceFuture +where + S: tower::Service, + S::Error: Into + 'static, +{ + type Output = Result; + + #[inline] + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + self.project().future.poll(cx).map_err(Into::into) + } +} diff --git a/examples/src/tower/client.rs b/examples/src/tower/client.rs index 0a33fffae..39fec5d47 100644 --- a/examples/src/tower/client.rs +++ b/examples/src/tower/client.rs @@ -44,7 +44,6 @@ mod service { use std::pin::Pin; use std::task::{Context, Poll}; use tonic::body::BoxBody; - use tonic::transport::Body; use tonic::transport::Channel; use tower::Service; @@ -59,7 +58,7 @@ mod service { } impl Service> for AuthSvc { - type Response = Response; + type Response = Response; type Error = Box; #[allow(clippy::type_complexity)] type Future = Pin> + Send>>; diff --git a/examples/src/tower/server.rs b/examples/src/tower/server.rs index cc85d62e5..b7066a1b6 100644 --- a/examples/src/tower/server.rs +++ b/examples/src/tower/server.rs @@ -1,4 +1,3 @@ -use hyper::Body; use std::{ pin::Pin, task::{Context, Poll}, @@ -84,9 +83,12 @@ struct MyMiddleware { type BoxFuture<'a, T> = Pin + Send + 'a>>; -impl Service> for MyMiddleware +impl Service> for MyMiddleware where - S: Service, Response = hyper::Response> + Clone + Send + 'static, + S: Service, Response = hyper::Response> + + Clone + + Send + + 'static, S::Future: Send + 'static, { type Response = S::Response; @@ -97,7 +99,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: hyper::Request) -> Self::Future { + fn call(&mut self, req: hyper::Request) -> Self::Future { // This is necessary because tonic internally uses `tower::buffer::Buffer`. // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 // for details on why this is necessary diff --git a/examples/src/uds/client.rs b/examples/src/uds/client.rs index e78531ac4..9a09e6981 100644 --- a/examples/src/uds/client.rs +++ b/examples/src/uds/client.rs @@ -5,6 +5,7 @@ pub mod hello_world { } use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use hyper_util::rt::TokioIo; #[cfg(unix)] use tokio::net::UnixStream; use tonic::transport::{Endpoint, Uri}; @@ -16,12 +17,13 @@ async fn main() -> Result<(), Box> { // We will ignore this uri because uds do not use it // if your connector does use the uri it will be provided // as the request to the `MakeConnection`. + let channel = Endpoint::try_from("http://[::]:50051")? - .connect_with_connector(service_fn(|_: Uri| { + .connect_with_connector(service_fn(|_: Uri| async { let path = "/tmp/tonic/helloworld"; // Connect to a Uds socket - UnixStream::connect(path) + Ok::<_, std::io::Error>(TokioIo::new(UnixStream::connect(path).await?)) })) .await?; diff --git a/interop/Cargo.toml b/interop/Cargo.toml index a58ef64cf..9a32b2a1d 100644 --- a/interop/Cargo.toml +++ b/interop/Cargo.toml @@ -19,9 +19,9 @@ async-stream = "0.3" strum = {version = "0.26", features = ["derive"]} pico-args = {version = "0.5", features = ["eq-separator"]} console = "0.15" -http = "0.2" -http-body = "0.4.2" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" prost = "0.12" tokio = {version = "1.0", features = ["rt-multi-thread", "time", "macros"]} tokio-stream = "0.1" diff --git a/interop/src/server.rs b/interop/src/server.rs index b32468866..38b1be65e 100644 --- a/interop/src/server.rs +++ b/interop/src/server.rs @@ -1,10 +1,10 @@ use crate::pb::{self, *}; use async_stream::try_stream; -use http::header::{HeaderMap, HeaderName, HeaderValue}; +use http::header::{HeaderName, HeaderValue}; use http_body::Body; use std::future::Future; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::Duration; use tokio_stream::StreamExt; use tonic::{body::BoxBody, server::NamedService, Code, Request, Response, Status}; @@ -180,9 +180,9 @@ impl EchoHeadersSvc { } } -impl Service> for EchoHeadersSvc +impl Service> for EchoHeadersSvc where - S: Service, Response = http::Response> + Send, + S: Service, Response = http::Response> + Send, S::Future: Send + 'static, { type Response = S::Response; @@ -193,7 +193,7 @@ where Ok(()).into() } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { let echo_header = req.headers().get("x-grpc-test-echo-initial").cloned(); let echo_trailer = req @@ -235,25 +235,19 @@ impl Body for MergeTrailers { type Data = B::Data; type Error = B::Error; - fn poll_data( - mut self: Pin<&mut Self>, + fn poll_frame( + self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - Pin::new(&mut self.inner).poll_data(cx) - } - - fn poll_trailers( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Pin::new(&mut self.inner).poll_trailers(cx).map_ok(|h| { - h.map(|mut headers| { - if let Some((key, value)) = &self.trailer { - headers.insert(key.clone(), value.clone()); + ) -> Poll, Self::Error>>> { + let this = self.get_mut(); + let mut frame = ready!(Pin::new(&mut this.inner).poll_frame(cx)?); + if let Some(frame) = frame.as_mut() { + if let Some(trailers) = frame.trailers_mut() { + if let Some((key, value)) = &this.trailer { + trailers.insert(key.clone(), value.clone()); } - - headers - }) - }) + } + } + Poll::Ready(frame.map(Ok)) } } diff --git a/tests/compression/Cargo.toml b/tests/compression/Cargo.toml index 5bc87c829..cf4da321b 100644 --- a/tests/compression/Cargo.toml +++ b/tests/compression/Cargo.toml @@ -8,9 +8,11 @@ version = "0.1.0" [dependencies] bytes = "1" -http = "0.2" -http-body = "0.4" -hyper = "0.14.3" +http = "1" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" paste = "1.0.12" pin-project = "1.0" prost = "0.12" @@ -18,7 +20,7 @@ tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} tokio-stream = "0.1" tonic = {path = "../../tonic", features = ["gzip", "zstd"]} tower = {version = "0.4", features = []} -tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]} +tower-http = {version = "0.5", features = ["map-response-body", "map-request-body"]} [build-dependencies] tonic-build = {path = "../../tonic-build" } diff --git a/tests/compression/src/util.rs b/tests/compression/src/util.rs index 28fa5d96a..d7e250ce4 100644 --- a/tests/compression/src/util.rs +++ b/tests/compression/src/util.rs @@ -1,6 +1,8 @@ use super::*; -use bytes::Bytes; -use http_body::Body; +use bytes::{Buf, Bytes}; +use http_body::{Body, Frame}; +use http_body_util::BodyExt as _; +use hyper_util::rt::TokioIo; use pin_project::pin_project; use std::{ pin::Pin, @@ -11,6 +13,7 @@ use std::{ task::{ready, Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tonic::body::BoxBody; use tonic::codec::CompressionEncoding; use tonic::transport::{server::Connected, Channel}; use tower_http::map_request_body::MapRequestBodyLayer; @@ -46,29 +49,22 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let this = self.project(); let counter: Arc = this.counter.clone(); - match ready!(this.inner.poll_data(cx)) { + match ready!(this.inner.poll_frame(cx)) { Some(Ok(chunk)) => { - println!("response body chunk size = {}", chunk.len()); - counter.fetch_add(chunk.len(), SeqCst); + println!("response body chunk size = {}", frame_data_length(&chunk)); + counter.fetch_add(frame_data_length(&chunk), SeqCst); Poll::Ready(Some(Ok(chunk))) } x => Poll::Ready(x), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - self.project().inner.poll_trailers(cx) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -78,28 +74,61 @@ where } } +fn frame_data_length(frame: &http_body::Frame) -> usize { + if let Some(data) = frame.data_ref() { + data.len() + } else { + 0 + } +} + +#[pin_project] +struct ChannelBody { + #[pin] + rx: tokio::sync::mpsc::Receiver>, +} + +impl ChannelBody { + pub fn new() -> (tokio::sync::mpsc::Sender>, Self) { + let (tx, rx) = tokio::sync::mpsc::channel(32); + (tx, Self { rx }) + } +} + +impl Body for ChannelBody +where + T: Buf, +{ + type Data = T; + type Error = tonic::Status; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let frame = ready!(self.project().rx.poll_recv(cx)); + Poll::Ready(frame.map(Ok)) + } +} + #[allow(dead_code)] pub fn measure_request_body_size_layer( bytes_sent_counter: Arc, -) -> MapRequestBodyLayer hyper::Body + Clone> { - MapRequestBodyLayer::new(move |mut body: hyper::Body| { - let (mut tx, new_body) = hyper::Body::channel(); +) -> MapRequestBodyLayer BoxBody + Clone> { + MapRequestBodyLayer::new(move |mut body: BoxBody| { + let (tx, new_body) = ChannelBody::new(); let bytes_sent_counter = bytes_sent_counter.clone(); tokio::spawn(async move { - while let Some(chunk) = body.data().await { + while let Some(chunk) = body.frame().await { let chunk = chunk.unwrap(); - println!("request body chunk size = {}", chunk.len()); - bytes_sent_counter.fetch_add(chunk.len(), SeqCst); - tx.send_data(chunk).await.unwrap(); - } - - if let Some(trailers) = body.trailers().await.unwrap() { - tx.send_trailers(trailers).await.unwrap(); + println!("request body chunk size = {}", frame_data_length(&chunk)); + bytes_sent_counter.fetch_add(frame_data_length(&chunk), SeqCst); + tx.send(chunk).await.unwrap(); } }); - new_body + new_body.boxed_unsync() }) } @@ -110,7 +139,7 @@ pub async fn mock_io_channel(client: tokio::io::DuplexStream) -> Channel { Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector(service_fn(move |_: Uri| { - let client = client.take().unwrap(); + let client = TokioIo::new(client.take().unwrap()); async move { Ok::<_, std::io::Error>(client) } })) .await diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 222d1919c..cfeebf725 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -17,12 +17,13 @@ tracing-subscriber = {version = "0.3"} [dev-dependencies] async-stream = "0.3" -http = "0.2" -http-body = "0.4" -hyper = "0.14" +http = "1" +http-body = "1" +hyper = "1" +hyper-util = "0.1" tokio-stream = {version = "0.1.5", features = ["net"]} tower = {version = "0.4", features = []} -tower-http = { version = "0.4", features = ["set-header", "trace"] } +tower-http = { version = "0.5", features = ["set-header", "trace"] } tower-service = "0.3" tracing = "0.1" diff --git a/tests/integration_tests/tests/complex_tower_middleware.rs b/tests/integration_tests/tests/complex_tower_middleware.rs index 5d7690be3..b1b669426 100644 --- a/tests/integration_tests/tests/complex_tower_middleware.rs +++ b/tests/integration_tests/tests/complex_tower_middleware.rs @@ -97,17 +97,10 @@ where type Data = B::Data; type Error = BoxError; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { - unimplemented!() - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { + ) -> Poll, Self::Error>>> { unimplemented!() } } diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs index 94fac8221..e87bb858f 100644 --- a/tests/integration_tests/tests/connect_info.rs +++ b/tests/integration_tests/tests/connect_info.rs @@ -51,6 +51,9 @@ async fn getting_connect_info() { #[cfg(unix)] pub mod unix { + use std::io; + + use hyper_util::rt::TokioIo; use tokio::{ net::{UnixListener, UnixStream}, sync::oneshot, @@ -106,7 +109,10 @@ pub mod unix { let path = unix_socket_path.clone(); let channel = Endpoint::try_from("http://[::]:50051") .unwrap() - .connect_with_connector(service_fn(move |_: Uri| UnixStream::connect(path.clone()))) + .connect_with_connector(service_fn(move |_: Uri| { + let path = path.clone(); + async move { Ok::<_, io::Error>(TokioIo::new(UnixStream::connect(path).await?)) } + })) .await .unwrap(); diff --git a/tests/integration_tests/tests/extensions.rs b/tests/integration_tests/tests/extensions.rs index b112f8e66..b2380181d 100644 --- a/tests/integration_tests/tests/extensions.rs +++ b/tests/integration_tests/tests/extensions.rs @@ -1,4 +1,4 @@ -use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; +use hyper::{Request as HyperRequest, Response as HyperResponse}; use integration_tests::{ pb::{test_client, test_server, Input, Output}, BoxFuture, @@ -16,6 +16,7 @@ use tonic::{ }; use tower_service::Service; +#[derive(Clone)] struct ExtensionValue(i32); #[tokio::test] @@ -112,9 +113,9 @@ struct InterceptedService { inner: S, } -impl Service> for InterceptedService +impl Service> for InterceptedService where - S: Service, Response = HyperResponse> + S: Service, Response = HyperResponse> + NamedService + Clone + Send @@ -129,7 +130,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, mut req: HyperRequest) -> Self::Future { + fn call(&mut self, mut req: HyperRequest) -> Self::Future { let clone = self.inner.clone(); let mut inner = std::mem::replace(&mut self.inner, clone); diff --git a/tests/integration_tests/tests/max_message_size.rs b/tests/integration_tests/tests/max_message_size.rs index 9ae524dbc..f03699cdf 100644 --- a/tests/integration_tests/tests/max_message_size.rs +++ b/tests/integration_tests/tests/max_message_size.rs @@ -1,5 +1,6 @@ use std::pin::Pin; +use hyper_util::rt::TokioIo; use integration_tests::{ pb::{test1_client, test1_server, Input1, Output1}, trace_init, @@ -163,7 +164,7 @@ async fn response_stream_limit() { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, @@ -332,7 +333,7 @@ async fn max_message_run(case: &TestCase) -> Result<(), Status> { async move { if let Some(client) = client { - Ok(client) + Ok(TokioIo::new(client)) } else { Err(std::io::Error::new( std::io::ErrorKind::Other, diff --git a/tests/integration_tests/tests/origin.rs b/tests/integration_tests/tests/origin.rs index f149dc68d..e41287245 100644 --- a/tests/integration_tests/tests/origin.rs +++ b/tests/integration_tests/tests/origin.rs @@ -76,9 +76,9 @@ struct OriginService { inner: S, } -impl Service> for OriginService +impl Service> for OriginService where - T: Service>, + T: Service>, T::Future: Send + 'static, T::Error: Into>, { @@ -90,7 +90,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { assert_eq!(req.uri().host(), Some("docs.rs")); let fut = self.inner.call(req); diff --git a/tests/integration_tests/tests/status.rs b/tests/integration_tests/tests/status.rs index 3fdabcd36..df6bc4b3b 100644 --- a/tests/integration_tests/tests/status.rs +++ b/tests/integration_tests/tests/status.rs @@ -1,5 +1,6 @@ use bytes::Bytes; use http::Uri; +use hyper_util::rt::TokioIo; use integration_tests::mock::MockStream; use integration_tests::pb::{ test_client, test_server, test_stream_client, test_stream_server, Input, InputStream, Output, @@ -183,7 +184,7 @@ async fn status_from_server_stream_with_source() { let channel = Endpoint::try_from("http://[::]:50051") .unwrap() .connect_with_connector_lazy(tower::service_fn(move |_: Uri| async move { - Err::(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) + Err::, _>(std::io::Error::new(std::io::ErrorKind::Other, "WTF")) })); let mut client = test_stream_client::TestStreamClient::new(channel); diff --git a/tonic-web/Cargo.toml b/tonic-web/Cargo.toml index d6649f65c..157605a95 100644 --- a/tonic-web/Cargo.toml +++ b/tonic-web/Cargo.toml @@ -18,14 +18,14 @@ version = "0.11.0" base64 = "0.22" bytes = "1" tokio-stream = "0.1" -http = "0.2" -http-body = "0.4" -hyper = {version = "0.14", default-features = false, features = ["stream"]} +http = "1" +http-body = "1" +http-body-util = "0.1" pin-project = "1" tonic = {version = "0.11", path = "../tonic", default-features = false} tower-service = "0.3" tower-layer = "0.3" -tower-http = { version = "0.4", features = ["cors"] } +tower-http = { version = "0.5", features = ["cors"] } tracing = "0.1" [dev-dependencies] diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index f52087e9e..5f67e4d5c 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -5,7 +5,7 @@ use std::task::{ready, Context, Poll}; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{header, HeaderMap, HeaderName, HeaderValue}; -use http_body::{Body, SizeHint}; +use http_body::{Body, Frame, SizeHint}; use pin_project::pin_project; use tokio_stream::Stream; use tonic::Status; @@ -63,9 +63,9 @@ pub struct GrpcWebCall { #[pin] inner: B, buf: BytesMut, + decoded: BytesMut, direction: Direction, encoding: Encoding, - poll_trailers: bool, client: bool, trailers: Option, } @@ -75,9 +75,9 @@ impl Default for GrpcWebCall { Self { inner: Default::default(), buf: Default::default(), + decoded: Default::default(), direction: Direction::Empty, encoding: Encoding::None, - poll_trailers: Default::default(), client: Default::default(), trailers: Default::default(), } @@ -108,9 +108,12 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(match direction { + Direction::Decode => BUFFER_SIZE, + _ => 0, + }), direction, encoding, - poll_trailers: true, client: true, trailers: None, } @@ -123,9 +126,9 @@ impl GrpcWebCall { (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), + decoded: BytesMut::with_capacity(0), direction, encoding, - poll_trailers: true, client: false, trailers: None, } @@ -160,24 +163,37 @@ where B: Body, B::Error: Error, { + // Poll body for data, decoding (e.g. via Base64 if necessary) and returning frames + // to the caller. If the caller is a client, it should look for trailers before + // returning these frames. fn poll_decode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { match self.encoding { Encoding::Base64 => loop { if let Some(bytes) = self.as_mut().decode_chunk()? { - return Poll::Ready(Some(Ok(bytes))); + return Poll::Ready(Some(Ok(Frame::data(bytes)))); } let mut this = self.as_mut().project(); - match ready!(this.inner.as_mut().poll_data(cx)) { - Some(Ok(data)) => this.buf.put(data), + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => this.buf.put(frame.into_data().unwrap()), + Some(Ok(frame)) if frame.is_trailers() => { + return Poll::Ready(Some(Err(internal_error( + "malformed base64 request has unencoded trailers", + )))) + } + Some(Ok(_)) => { + return Poll::Ready(Some(Err(internal_error("unexpected frame type")))) + } Some(Err(e)) => return Poll::Ready(Some(Err(internal_error(e)))), None => { return if this.buf.has_remaining() { Poll::Ready(Some(Err(internal_error("malformed base64 request")))) + } else if let Some(trailers) = this.trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } @@ -185,7 +201,7 @@ where } }, - Encoding::None => match ready!(self.project().inner.poll_data(cx)) { + Encoding::None => match ready!(self.project().inner.poll_frame(cx)) { Some(res) => Poll::Ready(Some(res.map_err(internal_error))), None => Poll::Ready(None), }, @@ -195,37 +211,31 @@ where fn poll_encode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Status>>> { let mut this = self.as_mut().project(); - if let Some(mut res) = ready!(this.inner.as_mut().poll_data(cx)) { - if *this.encoding == Encoding::Base64 { - res = res.map(|b| crate::util::base64::STANDARD.encode(b).into()) - } + match ready!(this.inner.as_mut().poll_frame(cx)) { + Some(Ok(frame)) if frame.is_data() => { + let mut res = frame.into_data().unwrap(); - return Poll::Ready(Some(res.map_err(internal_error))); - } - - // this flag is needed because the inner stream never - // returns Poll::Ready(None) when polled for trailers - if *this.poll_trailers { - return match ready!(this.inner.poll_trailers(cx)) { - Ok(Some(map)) => { - let mut frame = make_trailers_frame(map); - - if *this.encoding == Encoding::Base64 { - frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); - } + if *this.encoding == Encoding::Base64 { + res = crate::util::base64::STANDARD.encode(res).into(); + } - *this.poll_trailers = false; - Poll::Ready(Some(Ok(frame.into()))) + Poll::Ready(Some(Ok(Frame::data(res)))) + } + Some(Ok(frame)) if frame.is_trailers() => { + let trailers = frame.into_trailers().expect("must be trailers"); + let mut frame = make_trailers_frame(trailers); + if *this.encoding == Encoding::Base64 { + frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); } - Ok(None) => Poll::Ready(None), - Err(e) => Poll::Ready(Some(Err(internal_error(e)))), - }; + Poll::Ready(Some(Ok(Frame::data(frame.into())))) + } + Some(Ok(_)) => Poll::Ready(Some(Err(internal_error("unexepected frame type")))), + Some(Err(e)) => Poll::Ready(Some(Err(internal_error(e)))), + None => Poll::Ready(None), } - - Poll::Ready(None) } } @@ -237,28 +247,41 @@ where type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { if self.client && self.direction == Direction::Decode { let mut me = self.as_mut(); loop { - let incoming_buf = match ready!(me.as_mut().poll_decode(cx)) { - Some(Ok(incoming_buf)) => incoming_buf, - None => { - // TODO: Consider eofing here? - // Even if the buffer has more data, this will hit the eof branch - // of decode in tonic - return Poll::Ready(None); + match ready!(me.as_mut().poll_decode(cx)) { + Some(Ok(incoming_buf)) if incoming_buf.is_data() => { + me.as_mut() + .project() + .decoded + .put(incoming_buf.into_data().unwrap()); } + Some(Ok(incoming_buf)) if incoming_buf.is_trailers() => { + let trailers = incoming_buf.into_trailers().unwrap(); + match me.as_mut().project().trailers { + Some(current_trailers) => { + current_trailers.extend(trailers); + } + None => { + me.as_mut().project().trailers.replace(trailers); + } + } + continue; + } + Some(Ok(_)) => unreachable!("unexpected frame type"), + None => {} // No more data to decode, time to look for trailers Some(Err(e)) => return Poll::Ready(Some(Err(e))), }; - let buf = &mut me.as_mut().project().buf; - - buf.put(incoming_buf); + // Hold the incoming, decoded data until we have a full message + // or trailers to return. + let buf = me.as_mut().project().decoded; return match find_trailers(&buf[..])? { FindTrailers::Trailer(len) => { @@ -266,20 +289,24 @@ where let msg_buf = buf.copy_to_bytes(len); match decode_trailers_frame(buf.split().freeze()) { Ok(Some(trailers)) => { - self.project().trailers.replace(trailers); + me.as_mut().project().trailers.replace(trailers); } Err(e) => return Poll::Ready(Some(Err(e))), _ => {} } if msg_buf.has_remaining() { - Poll::Ready(Some(Ok(msg_buf))) + Poll::Ready(Some(Ok(Frame::data(msg_buf)))) + } else if let Some(trailers) = me.as_mut().project().trailers.take() { + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) } else { Poll::Ready(None) } } FindTrailers::IncompleteBuf => continue, - FindTrailers::Done(len) => Poll::Ready(Some(Ok(buf.split_to(len).freeze()))), + FindTrailers::Done(len) => { + Poll::Ready(Some(Ok(Frame::data(buf.split_to(len).freeze())))) + } }; } } @@ -291,14 +318,6 @@ where } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>, Self::Error>> { - let trailers = self.project().trailers.take(); - Poll::Ready(Ok(trailers)) - } - fn is_end_stream(&self) -> bool { self.inner.is_end_stream() } @@ -313,10 +332,10 @@ where B: Body, B::Error: Error, { - type Item = Result; + type Item = Result, Status>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Body::poll_data(self, cx) + self.poll_frame(cx) } } diff --git a/tonic-web/src/layer.rs b/tonic-web/src/layer.rs index 77b03c77e..7834f1990 100644 --- a/tonic-web/src/layer.rs +++ b/tonic-web/src/layer.rs @@ -24,7 +24,7 @@ impl Default for GrpcWebLayer { impl Layer for GrpcWebLayer where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index 16e57e19d..50ed8c0a8 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -127,7 +127,7 @@ type BoxError = Box; /// You can customize the CORS configuration composing the [`GrpcWebLayer`] with the cors layer of your choice. pub fn enable(service: S) -> CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -159,9 +159,9 @@ where #[derive(Debug, Clone)] pub struct CorsGrpcWeb(tower_http::cors::Cors>); -impl Service> for CorsGrpcWeb +impl Service> for CorsGrpcWeb where - S: Service, Response = http::Response>, + S: Service, Response = http::Response>, S: Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -169,7 +169,7 @@ where type Response = S::Response; type Error = S::Error; type Future = - > as Service>>::Future; + > as Service>>::Future; fn poll_ready( &mut self, @@ -178,7 +178,7 @@ where self.0.poll_ready(cx) } - fn call(&mut self, req: http::Request) -> Self::Future { + fn call(&mut self, req: http::Request) -> Self::Future { self.0.call(req) } } diff --git a/tonic-web/src/service.rs b/tonic-web/src/service.rs index af4c5276f..da65ba832 100644 --- a/tonic-web/src/service.rs +++ b/tonic-web/src/service.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use std::task::{ready, Context, Poll}; use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version}; -use hyper::Body; +use http_body_util::BodyExt; use pin_project::pin_project; use tonic::{ body::{empty_body, BoxBody}, @@ -50,7 +50,7 @@ impl GrpcWebService { impl GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, { fn response(&self, status: StatusCode) -> ResponseFuture { ResponseFuture { @@ -66,9 +66,9 @@ where } } -impl Service> for GrpcWebService +impl Service> for GrpcWebService where - S: Service, Response = Response> + Send + 'static, + S: Service, Response = Response> + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, { @@ -80,7 +80,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { match RequestKind::new(req.headers(), req.method(), req.version()) { // A valid grpc-web request, regardless of HTTP version. // @@ -202,7 +202,7 @@ impl<'a> RequestKind<'a> { // Mutating request headers to conform to a gRPC request is not really // necessary for us at this point. We could remove most of these except // maybe for inserting `header::TE`, which tonic should check? -fn coerce_request(mut req: Request, encoding: Encoding) -> Request { +fn coerce_request(mut req: Request, encoding: Encoding) -> Request { req.headers_mut().remove(header::CONTENT_LENGTH); req.headers_mut() @@ -216,8 +216,7 @@ fn coerce_request(mut req: Request, encoding: Encoding) -> Request { HeaderValue::from_static("identity,deflate,gzip"), ); - req.map(|b| GrpcWebCall::request(b, encoding)) - .map(Body::wrap_stream) + req.map(|b| GrpcWebCall::request(b, encoding).boxed_unsync()) } fn coerce_response(res: Response, encoding: Encoding) -> Response { @@ -246,7 +245,7 @@ mod tests { #[derive(Debug, Clone)] struct Svc; - impl tower_service::Service> for Svc { + impl tower_service::Service> for Svc { type Response = Response; type Error = String; type Future = BoxFuture; @@ -255,7 +254,7 @@ mod tests { Poll::Ready(Ok(())) } - fn call(&mut self, _: Request) -> Self::Future { + fn call(&mut self, _: Request) -> Self::Future { Box::pin(async { Ok(Response::new(empty_body())) }) } } @@ -266,15 +265,14 @@ mod tests { mod grpc_web { use super::*; - use http::HeaderValue; use tower_layer::Layer; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::POST) .header(CONTENT_TYPE, GRPC_WEB) .header(ORIGIN, "http://example.com") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -350,13 +348,13 @@ mod tests { mod options { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .method(Method::OPTIONS) .header(ORIGIN, "http://example.com") .header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web") .header(ACCESS_CONTROL_REQUEST_METHOD, "POST") - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -371,13 +369,12 @@ mod tests { mod grpc { use super::*; - use http::HeaderValue; - fn request() -> Request { + fn request() -> Request { Request::builder() .version(Version::HTTP_2) .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap() } @@ -397,7 +394,7 @@ mod tests { let req = Request::builder() .header(CONTENT_TYPE, GRPC) - .body(Body::empty()) + .body(empty_body()) .unwrap(); let res = svc.call(req).await.unwrap(); @@ -425,10 +422,10 @@ mod tests { mod other { use super::*; - fn request() -> Request { + fn request() -> Request { Request::builder() .header(CONTENT_TYPE, "application/text") - .body(Body::empty()) + .body(empty_body()) .unwrap() } diff --git a/tonic-web/tests/integration/Cargo.toml b/tonic-web/tests/integration/Cargo.toml index 5c6d5727e..38fd9ff32 100644 --- a/tonic-web/tests/integration/Cargo.toml +++ b/tonic-web/tests/integration/Cargo.toml @@ -9,7 +9,10 @@ license = "MIT" [dependencies] base64 = "0.22" bytes = "1.0" -hyper = "0.14" +http-body = "1" +http-body-util = "0.1" +hyper = "1" +hyper-util = "0.1" prost = "0.12" tokio = { version = "1", features = ["macros", "rt", "net"] } tokio-stream = { version = "0.1", features = ["net"] } diff --git a/tonic-web/tests/integration/tests/grpc_web.rs b/tonic-web/tests/integration/tests/grpc_web.rs index 3343d754c..2c57f2680 100644 --- a/tonic-web/tests/integration/tests/grpc_web.rs +++ b/tonic-web/tests/integration/tests/grpc_web.rs @@ -2,21 +2,27 @@ use std::net::SocketAddr; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http_body_util::{BodyExt as _, Full}; +use hyper::body::Incoming; use hyper::http::{header, StatusCode}; -use hyper::{Body, Client, Method, Request, Uri}; +use hyper::{Method, Request, Uri}; +use hyper_util::client::legacy::Client; +use hyper_util::rt::TokioExecutor; use prost::Message; use tokio::net::TcpListener; use tokio_stream::wrappers::TcpListenerStream; +use tonic::body::BoxBody; use tonic::transport::Server; use integration::pb::{test_server::TestServer, Input, Output}; use integration::Svc; +use tonic::Status; use tonic_web::GrpcWebLayer; #[tokio::test] async fn binary_request() { let server_url = spawn().await; - let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web", "grpc-web"); let res = client.request(req).await.unwrap(); @@ -39,7 +45,7 @@ async fn binary_request() { #[tokio::test] async fn text_request() { let server_url = spawn().await; - let client = Client::new(); + let client = Client::builder(TokioExecutor::new()).build_http(); let req = build_request(server_url, "grpc-web-text", "grpc-web-text"); let res = client.request(req).await.unwrap(); @@ -102,7 +108,7 @@ fn encode_body() -> Bytes { buf.split_to(len + 5).freeze() } -fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { +fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request { use header::{ACCEPT, CONTENT_TYPE, ORIGIN}; let request_uri = format!("{}/{}/{}", base_uri, "test.Test", "UnaryCall") @@ -123,12 +129,14 @@ fn build_request(base_uri: String, content_type: &str, accept: &str) -> Request< .header(ORIGIN, "http://example.com") .header(ACCEPT, format!("application/{}", accept)) .uri(request_uri) - .body(Body::from(bytes)) + .body(BoxBody::new( + Full::new(bytes).map_err(|err| Status::internal(err.to_string())), + )) .unwrap() } -async fn decode_body(body: Body, content_type: &str) -> (Output, Bytes) { - let mut body = hyper::body::to_bytes(body).await.unwrap(); +async fn decode_body(body: Incoming, content_type: &str) -> (Output, Bytes) { + let mut body = body.collect().await.unwrap().to_bytes(); if content_type == "application/grpc-web-text+proto" { body = integration::util::base64::STANDARD diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index 0d2be669f..b795cc5f9 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -37,10 +37,10 @@ transport = [ "dep:axum", "channel", "dep:h2", - "dep:hyper", - "dep:tokio", "tokio?/net", "tokio?/time", + "dep:hyper", "dep:hyper-util", "dep:hyper-timeout", + "dep:socket2", + "dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time", "dep:tower", - "dep:hyper-timeout", ] channel = [] @@ -51,10 +51,11 @@ channel = [] [dependencies] base64 = "0.22" bytes = "1.0" -http = "0.2" +http = "1" tracing = "0.1" -http-body = "0.4.4" +http-body = "1" +http-body-util = "0.1" percent-encoding = "2.1" pin-project = "1.0.11" tower-layer = "0.3" @@ -68,13 +69,15 @@ async-trait = {version = "0.1.13", optional = true} # transport async-stream = {version = "0.3", optional = true} -h2 = {version = "0.3.24", optional = true} -hyper = {version = "0.14.26", features = ["full"], optional = true} -hyper-timeout = {version = "0.4", optional = true} -tokio = {version = "1.0.1", optional = true} -tokio-stream = "0.1" +h2 = {version = "0.4", optional = true} +hyper = {version = "1", features = ["full"], optional = true} +hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true } +hyper-timeout = {version = "0.5", optional = true} +socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } +tokio = {version = "1", default-features = false, optional = true} +tokio-stream = { version = "0.1", features = ["net"] } tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true} -axum = {version = "0.6.9", default-features = false, optional = true} +axum = {version = "0.7", default-features = false, optional = true} # rustls rustls-pemfile = { version = "2.0", optional = true } diff --git a/tonic/benches/decode.rs b/tonic/benches/decode.rs index 5c7cd0159..22ab6d9d4 100644 --- a/tonic/benches/decode.rs +++ b/tonic/benches/decode.rs @@ -1,6 +1,6 @@ use bencher::{benchmark_group, benchmark_main, Bencher}; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use http_body::Body; +use http_body::{Body, Frame, SizeHint}; use std::{ fmt::{Error, Formatter}, pin::Pin, @@ -58,23 +58,24 @@ impl Body for MockBody { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll>> { + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { if self.data.has_remaining() { let split = std::cmp::min(self.chunk_size, self.data.remaining()); - Poll::Ready(Some(Ok(self.data.split_to(split)))) + Poll::Ready(Some(Ok(Frame::data(self.data.split_to(split))))) } else { Poll::Ready(None) } } - fn poll_trailers( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) + fn is_end_stream(&self) -> bool { + !self.data.is_empty() + } + + fn size_hint(&self) -> SizeHint { + SizeHint::with_exact(self.data.len() as u64) } } diff --git a/tonic/src/body.rs b/tonic/src/body.rs index ef95eec47..428c0dade 100644 --- a/tonic/src/body.rs +++ b/tonic/src/body.rs @@ -1,9 +1,9 @@ //! HTTP specific body utilities. -use http_body::Body; +use http_body_util::BodyExt; /// A type erased HTTP body used for tonic services. -pub type BoxBody = http_body::combinators::UnsyncBoxBody; +pub type BoxBody = http_body_util::combinators::UnsyncBoxBody; /// Convert a [`http_body::Body`] into a [`BoxBody`]. pub(crate) fn boxed(body: B) -> BoxBody @@ -16,7 +16,7 @@ where /// Create an empty `BoxBody` pub fn empty_body() -> BoxBody { - http_body::Empty::new() + http_body_util::Empty::new() .map_err(|err| match err {}) .boxed_unsync() } diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 020551704..2117df8b5 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -2,8 +2,9 @@ use super::compression::{decompress, CompressionEncoding, CompressionSettings}; use super::{BufferSettings, DecodeBuf, Decoder, DEFAULT_MAX_RECV_MESSAGE_SIZE, HEADER_SIZE}; use crate::{body::BoxBody, metadata::MetadataMap, Code, Status}; use bytes::{Buf, BufMut, BytesMut}; -use http::StatusCode; +use http::{HeaderMap, StatusCode}; use http_body::Body; +use http_body_util::BodyExt; use std::{ fmt, future, pin::Pin, @@ -27,7 +28,7 @@ struct StreamingInner { state: State, direction: Direction, buf: BytesMut, - trailers: Option, + trailers: Option, decompress_buf: BytesMut, encoding: Option, max_message_size: Option, @@ -121,7 +122,7 @@ impl Streaming { decoder: Box::new(decoder), inner: StreamingInner { body: body - .map_data(|mut buf| buf.copy_to_bytes(buf.remaining())) + .map_frame(|frame| frame.map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))) .map_err(|err| Status::map_error(err.into())) .boxed_unsync(), state: State::ReadHeader, @@ -239,8 +240,8 @@ impl StreamingInner { } // Returns Some(()) if data was found or None if the loop in `poll_next` should break - fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { - let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) { + fn poll_frame(&mut self, cx: &mut Context<'_>) -> Poll, Status>> { + let chunk = match ready!(Pin::new(&mut self.body).poll_frame(cx)) { Some(Ok(d)) => Some(d), Some(Err(status)) => { if self.direction == Direction::Request && status.code() == Code::Cancelled { @@ -254,9 +255,26 @@ impl StreamingInner { None => None, }; - Poll::Ready(if let Some(data) = chunk { - self.buf.put(data); - Ok(Some(())) + Poll::Ready(if let Some(frame) = chunk { + match frame { + frame if frame.is_data() => { + self.buf.put(frame.into_data().unwrap()); + Ok(Some(())) + } + frame if frame.is_trailers() => { + match &mut self.trailers { + Some(trailers) => { + trailers.extend(frame.into_trailers().unwrap()); + } + None => { + self.trailers = Some(frame.into_trailers().unwrap()); + } + } + + Ok(None) + } + frame => panic!("unexpected frame: {:?}", frame), + } } else { // FIXME: improve buf usage. if self.buf.has_remaining() { @@ -271,27 +289,18 @@ impl StreamingInner { }) } - fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll> { + fn response(&mut self) -> Result<(), Status> { if let Direction::Response(status) = self.direction { - match ready!(Pin::new(&mut self.body).poll_trailers(cx)) { - Ok(trailer) => { - if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) { - if let Some(e) = e { - return Poll::Ready(Err(e)); - } else { - return Poll::Ready(Ok(())); - } - } else { - self.trailers = trailer.map(MetadataMap::from_headers); - } - } - Err(status) => { - debug!("decoder inner trailers error: {:?}", status); - return Poll::Ready(Err(status)); + if let Err(e) = crate::status::infer_grpc_status(self.trailers.as_ref(), status) { + if let Some(e) = e { + // If the trailers contain a grpc-status, then we should return that as the error + // and otherwise stop the stream (by taking the error state) + self.trailers.take(); + return Err(e); } } } - Poll::Ready(Ok(())) + Ok(()) } } @@ -351,7 +360,7 @@ impl Streaming { // Shortcut to see if we already pulled the trailers in the stream step // we need to do that so that the stream can error on trailing grpc-status if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } // To fetch the trailers we must clear the body and drop it. @@ -360,16 +369,11 @@ impl Streaming { // Since we call poll_trailers internally on poll_next we need to // check if it got cached again. if let Some(trailers) = self.inner.trailers.take() { - return Ok(Some(trailers)); + return Ok(Some(MetadataMap::from_headers(trailers))); } - // Trailers were not caught during poll_next and thus lets poll for - // them manually. - let map = future::poll_fn(|cx| Pin::new(&mut self.inner.body).poll_trailers(cx)) - .await - .map_err(|e| Status::from_error(Box::new(e))); - - map.map(|x| x.map(MetadataMap::from_headers)) + // We've polled through all the frames, and still no trailers, return None + Ok(None) } fn decode_chunk(&mut self) -> Result, Status> { @@ -395,20 +399,17 @@ impl Stream for Streaming { return Poll::Ready(None); } - // FIXME: implement the ability to poll trailers when we _know_ that - // the consumer of this stream will only poll for the first message. - // This means we skip the poll_trailers step. if let Some(item) = self.decode_chunk()? { return Poll::Ready(Some(Ok(item))); } - match ready!(self.inner.poll_data(cx))? { + match ready!(self.inner.poll_frame(cx))? { Some(()) => (), None => break, } } - Poll::Ready(match ready!(self.inner.poll_response(cx)) { + Poll::Ready(match self.inner.response() { Ok(()) => None, Err(err) => Some(Err(err)), }) diff --git a/tonic/src/codec/encode.rs b/tonic/src/codec/encode.rs index 88ec9568e..82b4eb61d 100644 --- a/tonic/src/codec/encode.rs +++ b/tonic/src/codec/encode.rs @@ -5,7 +5,7 @@ use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, H use crate::{Code, Status}; use bytes::{BufMut, Bytes, BytesMut}; use http::HeaderMap; -use http_body::Body; +use http_body::{Body, Frame}; use pin_project::pin_project; use std::{ pin::Pin, @@ -298,22 +298,21 @@ where } impl EncodeState { - fn trailers(&mut self) -> Result, Status> { + fn trailers(&mut self) -> Option> { match self.role { - Role::Client => Ok(None), + Role::Client => None, Role::Server => { if self.is_end_stream { - return Ok(None); + return None; } + self.is_end_stream = true; let status = if let Some(status) = self.error.take() { - self.is_end_stream = true; status } else { Status::new(Code::Ok, "") }; - - Ok(Some(status.to_header_map()?)) + Some(status.to_header_map()) } } } @@ -330,38 +329,25 @@ where self.state.is_end_stream } - fn size_hint(&self) -> http_body::SizeHint { - let sh = self.inner.size_hint(); - let mut size_hint = http_body::SizeHint::new(); - size_hint.set_lower(sh.0 as u64); - if let Some(upper) = sh.1 { - size_hint.set_upper(upper as u64); - } - size_hint - } - - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { let self_proj = self.project(); match ready!(self_proj.inner.poll_next(cx)) { - Some(Ok(d)) => Some(Ok(d)).into(), + Some(Ok(d)) => Some(Ok(Frame::data(d))).into(), Some(Err(status)) => match self_proj.state.role { Role::Client => Some(Err(status)).into(), Role::Server => { - self_proj.state.error = Some(status); - None.into() + self_proj.state.is_end_stream = true; + Some(Ok(Frame::trailers(status.to_header_map()?))).into() } }, - None => None.into(), + None => self_proj + .state + .trailers() + .map(|t| t.map(Frame::trailers)) + .into(), } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Status>> { - Poll::Ready(self.project().state.trailers()) - } } diff --git a/tonic/src/codec/prost.rs b/tonic/src/codec/prost.rs index 4e925ba6e..8237daf32 100644 --- a/tonic/src/codec/prost.rs +++ b/tonic/src/codec/prost.rs @@ -156,6 +156,7 @@ mod tests { use crate::{Code, Status}; use bytes::{Buf, BufMut, BytesMut}; use http_body::Body; + use http_body_util::BodyExt as _; use std::pin::pin; const LEN: usize = 10000; @@ -238,7 +239,7 @@ mod tests { None, )); - while let Some(r) = body.data().await { + while let Some(r) = body.frame().await { r.unwrap(); } } @@ -260,12 +261,15 @@ mod tests { Some(MAX_MESSAGE_SIZE), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "11" @@ -292,12 +296,15 @@ mod tests { Some(usize::MAX), )); - assert!(body.data().await.is_none()); + let frame = body + .frame() + .await + .expect("at least one frame") + .expect("no error polling frame"); assert_eq!( - body.trailers() - .await - .expect("no error polling trailers") - .expect("some trailers") + frame + .into_trailers() + .expect("got trailers") .get("grpc-status") .expect("grpc-status header"), "8" @@ -343,7 +350,7 @@ mod tests { mod body { use crate::Status; use bytes::Bytes; - use http_body::Body; + use http_body::{Body, Frame}; use std::{ pin::Pin, task::{Context, Poll}, @@ -374,10 +381,10 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { // every other call to poll_data returns data let should_send = self.count % 2 == 0; let data_len = self.data.len(); @@ -395,18 +402,11 @@ mod tests { }; // make some fake progress self.count += 1; - result + result.map(|opt| opt.map(|res| res.map(|data| Frame::data(data)))) } else { Poll::Ready(None) } } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } } } diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs index 37d84b87b..32b9ad021 100644 --- a/tonic/src/extensions.rs +++ b/tonic/src/extensions.rs @@ -24,7 +24,7 @@ impl Extensions { /// If a extension of this type already existed, it will /// be returned. #[inline] - pub fn insert(&mut self, val: T) -> Option { + pub fn insert(&mut self, val: T) -> Option { self.inner.insert(val) } diff --git a/tonic/src/request.rs b/tonic/src/request.rs index a27a7070c..f2cca7c74 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -313,6 +313,7 @@ impl Request { /// ```no_run /// use tonic::{Request, service::interceptor}; /// + /// #[derive(Clone)] // Extensions must be Clone /// struct MyExtension { /// some_piece_of_data: String, /// } @@ -440,7 +441,6 @@ pub(crate) enum SanitizeHeaders { #[cfg(test)] mod tests { use super::*; - use crate::metadata::MetadataValue; use http::Uri; #[test] diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs index cadff466f..ebe78093d 100644 --- a/tonic/src/service/interceptor.rs +++ b/tonic/src/service/interceptor.rs @@ -232,11 +232,8 @@ where mod tests { #[allow(unused_imports)] use super::*; - use http::header::HeaderMap; - use std::{ - pin::Pin, - task::{Context, Poll}, - }; + use http_body::Frame; + use http_body_util::Empty; use tower::ServiceExt; #[derive(Debug, Default)] @@ -246,19 +243,12 @@ mod tests { type Data = Bytes; type Error = Status; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { Poll::Ready(None) } - - fn poll_trailers( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - Poll::Ready(Ok(None)) - } } #[tokio::test] @@ -318,17 +308,17 @@ mod tests { #[tokio::test] async fn doesnt_change_http_method() { - let svc = tower::service_fn(|request: http::Request| async move { + let svc = tower::service_fn(|request: http::Request>| async move { assert_eq!(request.method(), http::Method::OPTIONS); - Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty())) + Ok::<_, hyper::Error>(hyper::Response::new(Empty::new())) }); let svc = InterceptedService::new(svc, Ok); let request = http::Request::builder() .method(http::Method::OPTIONS) - .body(hyper::Body::empty()) + .body(Empty::new()) .unwrap(); svc.oneshot(request).await.unwrap(); diff --git a/tonic/src/status.rs b/tonic/src/status.rs index 108ee3cf2..968693f87 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -412,13 +412,7 @@ impl Status { // > status. Note that the frequency of PINGs is highly dependent on the network // > environment, implementations are free to adjust PING frequency based on network and // > application requirements, which is why it's mapped to unavailable here. - // - // Likewise, if we are unable to connect to the server, map this to UNAVAILABLE. This is - // consistent with the behavior of a C++ gRPC client when the server is not running, and - // matches the spec of: - // > The service is currently unavailable. This is most likely a transient condition that - // > can be corrected if retried with a backoff. - if err.is_timeout() || err.is_connect() { + if err.is_timeout() { return Some(Status::unavailable(err.to_string())); } @@ -620,6 +614,16 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option { return Some(Status::cancelled(timeout.to_string())); } + // If we are unable to connect to the server, map this to UNAVAILABLE. This is + // consistent with the behavior of a C++ gRPC client when the server is not running, and + // matches the spec of: + // > The service is currently unavailable. This is most likely a transient condition that + // > can be corrected if retried with a backoff. + #[cfg(feature = "transport")] + if let Some(connect) = err.downcast_ref::() { + return Some(Status::unavailable(connect.to_string())); + } + #[cfg(feature = "transport")] if let Some(hyper) = err .downcast_ref::() diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 995e2a15b..6014960a8 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -7,8 +7,10 @@ use crate::transport::service::TlsConnector; use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; +use hyper::rt; +use hyper_util::client::legacy::connect::HttpConnector; use std::{fmt, future::Future, pin::Pin, str::FromStr, time::Duration}; -use tower::make::MakeConnection; +use tower_service::Service; /// Channel builder. /// @@ -332,7 +334,7 @@ impl Endpoint { /// Create a channel from this config. pub async fn connect(&self) -> Result { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -348,7 +350,7 @@ impl Endpoint { /// The channel returned by this method does not attempt to connect to the endpoint until first /// use. pub fn connect_lazy(&self) -> Channel { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.enforce_http(false); http.set_nodelay(self.tcp_nodelay); http.set_keepalive(self.tcp_keepalive); @@ -368,8 +370,8 @@ impl Endpoint { /// The [`connect_timeout`](Endpoint::connect_timeout) will still be applied. pub async fn connect_with_connector(&self, connector: C) -> Result where - C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C: Service + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { @@ -393,8 +395,8 @@ impl Endpoint { /// uses a Unix socket transport. pub fn connect_with_connector_lazy(&self, connector: C) -> Channel where - C: MakeConnection + Send + 'static, - C::Connection: Unpin + Send + 'static, + C: Service + Send + 'static, + C::Response: rt::Read + rt::Write + Send + Unpin + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index b510a6980..0983725f8 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -17,7 +17,7 @@ use http::{ uri::{InvalidUri, Uri}, Request, Response, }; -use hyper::client::connect::Connection as HyperConnection; +use hyper_util::client::legacy::connect::Connection as HyperConnection; use std::{ fmt, future::Future, @@ -25,11 +25,9 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::mpsc::{channel, Sender}, -}; +use tokio::sync::mpsc::{channel, Sender}; +use hyper::rt; use tower::balance::p2c::Balance; use tower::{ buffer::{self, Buffer}, @@ -38,13 +36,13 @@ use tower::{ Service, }; -type Svc = Either, Response, crate::Error>>; +type Svc = Either, Response, crate::Error>>; const DEFAULT_BUFFER_SIZE: usize = 1024; /// A default batteries included `transport` channel. /// -/// This provides a fully featured http2 gRPC client based on [`hyper::Client`] +/// This provides a fully featured http2 gRPC client based on `hyper` /// and `tower` services. /// /// # Multiplexing requests @@ -152,7 +150,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); @@ -169,7 +167,7 @@ impl Channel { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); let executor = endpoint.executor.clone(); @@ -201,7 +199,7 @@ impl Channel { } impl Service> for Channel { - type Response = http::Response; + type Response = http::Response; type Error = super::Error; type Future = ResponseFuture; @@ -217,7 +215,7 @@ impl Service> for Channel { } impl Future for ResponseFuture { - type Output = Result, super::Error>; + type Output = Result, super::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let val = ready!(Pin::new(&mut self.inner).poll(cx)).map_err(super::Error::from_source)?; diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 758bdb7d8..767534748 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -1,7 +1,7 @@ //! Batteries included server and client. //! //! This module provides a set of batteries included, fully featured and -//! fast set of HTTP/2 server and client's. These components each provide a or +//! fast set of HTTP/2 server and client's. These components each provide a //! `rustls` tls backend when the respective feature flag is enabled, and //! provides builders to configure transport behavior. //! @@ -38,7 +38,7 @@ //! .connect() //! .await?; //! -//! channel.call(Request::new(BoxBody::empty())).await?; +//! channel.call(Request::new(tonic::body::empty_body())).await?; //! # Ok(()) //! # } //! ``` @@ -46,21 +46,23 @@ //! ## Server //! //! ```no_run +//! # use std::convert::Infallible; //! # #[cfg(feature = "rustls")] //! # use tonic::transport::{Server, Identity, ServerTlsConfig}; +//! # use tonic::body::BoxBody; //! # use tower::Service; //! # #[cfg(feature = "rustls")] //! # async fn do_thing() -> Result<(), Box> { //! # #[derive(Clone)] //! # pub struct Svc; -//! # impl Service> for Svc { -//! # type Response = hyper::Response; -//! # type Error = tonic::Status; +//! # impl Service> for Svc { +//! # type Response = hyper::Response; +//! # type Error = Infallible; //! # type Future = std::future::Ready>; //! # fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll> { //! # Ok(()).into() //! # } -//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { +//! # fn call(&mut self, _req: hyper::Request) -> Self::Future { //! # unimplemented!() //! # } //! # } @@ -104,11 +106,13 @@ pub use self::error::Error; pub use self::server::Server; #[doc(inline)] pub use self::service::grpc_timeout::TimeoutExpired; +pub(crate) use self::service::ConnectError; + #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; -pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; -pub use hyper::{Body, Uri}; +pub use axum::{body::Body as AxumBoxBody, Router as AxumRouter}; +pub use hyper::{body::Body, Uri}; #[cfg(feature = "tls")] pub use tokio_rustls::rustls::pki_types::CertificateDer; diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index 122f13baf..49c086a59 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,4 +1,3 @@ -use hyper::server::conn::AddrStream; use std::net::SocketAddr; use tokio::net::TcpStream; @@ -86,17 +85,6 @@ impl TcpConnectInfo { } } -impl Connected for AddrStream { - type ConnectInfo = TcpConnectInfo; - - fn connect_info(&self) -> Self::ConnectInfo { - TcpConnectInfo { - local_addr: Some(self.local_addr()), - remote_addr: Some(self.remote_addr()), - } - } -} - impl Connected for TcpStream { type ConnectInfo = TcpConnectInfo; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index bc1bb7650..7f5f76c25 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,20 +1,18 @@ use super::{Connected, Server}; use crate::transport::service::ServerIo; -use hyper::server::{ - accept::Accept, - conn::{AddrIncoming, AddrStream}, -}; use std::{ - net::SocketAddr, + net::{SocketAddr, TcpListener as StdTcpListener}, pin::{pin, Pin}, - task::{Context, Poll}, + task::{ready, Context, Poll}, time::Duration, }; use tokio::{ io::{AsyncRead, AsyncWrite}, - net::TcpListener, + net::{TcpListener, TcpStream}, }; +use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::{Stream, StreamExt}; +use tracing::warn; #[cfg(not(feature = "tls"))] pub(crate) fn tcp_incoming( @@ -127,7 +125,9 @@ enum SelectOutput { /// of `AsyncRead + AsyncWrite` that communicate with clients that connect to a socket address. #[derive(Debug)] pub struct TcpIncoming { - inner: AddrIncoming, + inner: TcpListenerStream, + nodelay: bool, + keepalive: Option, } impl TcpIncoming { @@ -139,13 +139,13 @@ impl TcpIncoming { /// ```no_run /// # use tower_service::Service; /// # use http::{request::Request, response::Response}; - /// # use tonic::{body::BoxBody, server::NamedService, transport::{Body, Server, server::TcpIncoming}}; + /// # use tonic::{body::BoxBody, server::NamedService, transport::{Server, server::TcpIncoming}}; /// # use core::convert::Infallible; /// # use std::error::Error; /// # fn main() { } // Cannot have type parameters, hence instead define: /// # fn run(some_service: S) -> Result<(), Box> /// # where - /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, + /// # S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send + 'static, /// # S::Future: Send + 'static, /// # { /// // Find a free port @@ -167,10 +167,15 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::bind(&addr)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + let std_listener = StdTcpListener::bind(addr)?; + std_listener.set_nonblocking(true)?; + + let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?); + Ok(Self { + inner, + nodelay, + keepalive, + }) } /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`. @@ -179,18 +184,43 @@ impl TcpIncoming { nodelay: bool, keepalive: Option, ) -> Result { - let mut inner = AddrIncoming::from_listener(listener)?; - inner.set_nodelay(nodelay); - inner.set_keepalive(keepalive); - Ok(TcpIncoming { inner }) + Ok(Self { + inner: TcpListenerStream::new(listener), + nodelay, + keepalive, + }) } } impl Stream for TcpIncoming { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_accept(cx) + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(stream)) => { + set_accepted_socket_options(&stream, self.nodelay, self.keepalive); + Some(Ok(stream)).into() + } + other => Poll::Ready(other), + } + } +} + +// Consistent with hyper-0.14, this function does not return an error. +fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option) { + if nodelay { + if let Err(e) = stream.set_nodelay(true) { + warn!("error trying to set TCP nodelay: {}", e); + } + } + + if let Some(timeout) = keepalive { + let sock_ref = socket2::SockRef::from(&stream); + let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); + + if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { + warn!("error trying to set TCP keepalive: {}", e); + } } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7f2ffde2b..3fa406c77 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -9,10 +9,13 @@ mod tls; #[cfg(unix)] mod unix; -pub use super::service::Routes; -pub use super::service::RoutesBuilder; +use tokio_stream::StreamExt as _; +use tracing::{debug, trace}; + +pub use super::service::{Routes, RoutesBuilder}; pub use conn::{Connected, TcpConnectInfo}; +use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; @@ -35,20 +38,20 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, ServerIo}; -use crate::body::BoxBody; +use crate::body::{boxed, BoxBody}; use crate::server::NamedService; use bytes::Bytes; use http::{Request, Response}; -use http_body::Body as _; -use hyper::{server::accept, Body}; +use http_body_util::BodyExt; +use hyper::body::Incoming; use pin_project::pin_project; use std::{ convert::Infallible, fmt, - future::{self, Future}, + future::{self, poll_fn, Future}, marker::PhantomData, net::SocketAddr, - pin::Pin, + pin::{pin, Pin}, sync::Arc, task::{ready, Context, Poll}, time::Duration, @@ -59,20 +62,21 @@ use tower::{ layer::util::{Identity, Stack}, layer::Layer, limit::concurrency::ConcurrencyLimitLayer, - util::Either, - Service, ServiceBuilder, + util::{BoxCloneService, Either, Oneshot}, + Service, ServiceBuilder, ServiceExt, }; -type BoxHttpBody = http_body::combinators::UnsyncBoxBody; -type BoxService = tower::util::BoxService, Response, crate::Error>; +type BoxHttpBody = crate::body::BoxBody; +type BoxError = crate::Error; +type BoxService = tower::util::BoxCloneService, Response, crate::Error>; type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; /// A default batteries included `transport` server. /// -/// This is a wrapper around [`hyper::Server`] and provides an easy builder -/// pattern style builder [`Server`]. This builder exposes easy configuration parameters +/// This provides an easy builder pattern style builder [`Server`] on top of +/// `hyper` connections. This builder exposes easy configuration parameters /// for providing a fully featured http2 based gRPC server. This should provide /// a very good out of the box http2 server for use with tonic but is also a /// reference implementation that should be a good starting point for anyone @@ -122,7 +126,7 @@ impl Default for Server { } } -/// A stack based `Service` router. +/// A stack based [`Service`] router. #[derive(Debug)] pub struct Router { server: Server, @@ -359,7 +363,7 @@ impl Server { /// route around different services. pub fn add_service(&mut self, svc: S) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -380,7 +384,7 @@ impl Server { /// As a result, one cannot use this to toggle between two identically named implementations. pub fn add_optional_service(&mut self, svc: Option) -> Router where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -494,9 +498,11 @@ impl Server { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send + 'static, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IO::ConnectInfo: Clone + Send + Sync + 'static, @@ -523,10 +529,8 @@ impl Server { let svc = self.service_builder.service(svc); - let tcp = incoming::tcp_incoming(incoming, self); - let incoming = accept::from_stream::<_, _, crate::Error>(tcp); - - let svc = MakeSvc { + let incoming = incoming::tcp_incoming(incoming, self); + let mut svc = MakeSvc { inner: svc, concurrency_limit, timeout, @@ -534,31 +538,189 @@ impl Server { _io: PhantomData, }; - let server = hyper::Server::builder(incoming) - .http2_only(http2_only) - .http2_initial_connection_window_size(init_connection_window_size) - .http2_initial_stream_window_size(init_stream_window_size) - .http2_max_concurrent_streams(max_concurrent_streams) - .http2_keep_alive_interval(http2_keepalive_interval) - .http2_keep_alive_timeout(http2_keepalive_timeout) - .http2_adaptive_window(http2_adaptive_window.unwrap_or_default()) - .http2_max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) - .http2_max_frame_size(max_frame_size); - - if let Some(signal) = signal { - server - .serve(svc) - .with_graceful_shutdown(signal) - .await - .map_err(super::Error::from_source)? - } else { - server.serve(svc).await.map_err(super::Error::from_source)?; + let server = { + let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + + if http2_only { + builder = builder.http2_only(); + } + + builder + .http2() + .initial_connection_window_size(init_connection_window_size) + .initial_stream_window_size(init_stream_window_size) + .max_concurrent_streams(max_concurrent_streams) + .keep_alive_interval(http2_keepalive_interval) + .keep_alive_timeout(http2_keepalive_timeout) + .adaptive_window(http2_adaptive_window.unwrap_or_default()) + .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) + .max_frame_size(max_frame_size); + + builder + }; + + let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); + let signal_tx = Arc::new(signal_tx); + + let graceful = signal.is_some(); + let mut sig = pin!(Fuse { inner: signal }); + let mut incoming = pin!(incoming); + + loop { + tokio::select! { + _ = &mut sig => { + trace!("signal received, shutting down"); + break; + }, + io = incoming.next() => { + let io = match io { + Some(Ok(io)) => io, + Some(Err(e)) => { + trace!("error accepting connection: {:#}", e); + continue; + }, + None => { + break + }, + }; + + trace!("connection accepted"); + + poll_fn(|cx| svc.poll_ready(cx)) + .await + .map_err(super::Error::from_source)?; + + let req_svc = svc + .call(&io) + .await + .map_err(super::Error::from_source)?; + let hyper_svc = TowerToHyperService::new(req_svc); + + serve_connection(io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone())); + } + } + } + + if graceful { + let _ = signal_tx.send(()); + drop(signal_rx); + trace!( + "waiting for {} connections to close", + signal_tx.receiver_count() + ); + + // Wait for all connections to close + signal_tx.closed().await; } Ok(()) } } +// This is moved to its own function as a way to get around +// https://github.com/rust-lang/rust/issues/102211 +fn serve_connection( + io: ServerIo, + hyper_svc: TowerToHyperService, + builder: ConnectionBuilder, + mut watcher: Option>, +) where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, +{ + tokio::spawn(async move { + { + let mut sig = pin!(Fuse { + inner: watcher.as_mut().map(|w| w.changed()), + }); + + let mut conn = pin!(builder.serve_connection(TokioIo::new(io), hyper_svc)); + + loop { + tokio::select! { + rv = &mut conn => { + if let Err(err) = rv { + debug!("failed serving connection: {:#}", err); + } + break; + }, + _ = &mut sig => { + conn.as_mut().graceful_shutdown(); + } + } + } + } + + drop(watcher); + trace!("connection closed"); + }); +} + +type ConnectionBuilder = hyper_util::server::conn::auto::Builder; + +/// An adaptor which converts a [`tower::Service`] to a [`hyper::service::Service`]. +/// +/// The [`hyper::service::Service`] trait is used by hyper to handle incoming requests, +/// and does not support the `poll_ready` method that is used by tower services. +#[derive(Debug, Copy, Clone)] +pub(crate) struct TowerToHyperService { + service: S, +} + +impl TowerToHyperService { + /// Create a new `TowerToHyperService` from a tower service. + pub(crate) fn new(service: S) -> Self { + Self { service } + } +} + +impl hyper::service::Service> for TowerToHyperService +where + S: tower_service::Service> + Clone, + S::Error: Into + 'static, +{ + type Response = S::Response; + type Error = super::Error; + type Future = TowerToHyperServiceFuture>; + + fn call(&self, req: Request) -> Self::Future { + let req = req.map(crate::body::boxed); + TowerToHyperServiceFuture { + future: self.service.clone().oneshot(req), + } + } +} + +/// Future returned by [`TowerToHyperService`]. +#[derive(Debug)] +#[pin_project] +pub(crate) struct TowerToHyperServiceFuture +where + S: tower_service::Service, +{ + #[pin] + future: Oneshot, +} + +impl Future for TowerToHyperServiceFuture +where + S: tower_service::Service, + S::Error: Into + 'static, +{ + type Output = Result; + + #[inline] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project() + .future + .poll(cx) + .map_err(super::Error::from_source) + } +} + impl Router { pub(crate) fn new(server: Server, routes: Routes) -> Self { Self { server, routes } @@ -569,7 +731,7 @@ impl Router { /// Add a new service to this router. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -588,7 +750,7 @@ impl Router { #[allow(clippy::type_complexity)] pub fn add_optional_service(mut self, svc: Option) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -613,10 +775,12 @@ impl Router { /// [tokio]: https://docs.rs/tokio pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> where - L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L: Layer + Clone, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -644,9 +808,11 @@ impl Router { ) -> Result<(), super::Error> where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -673,9 +839,11 @@ impl Router { IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -708,9 +876,11 @@ impl Router { IE: Into, F: Future, L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -723,9 +893,11 @@ impl Router { pub fn into_service(self) -> L::Service where L: Layer, - L::Service: Service, Response = Response> + Clone + Send + 'static, - <>::Service as Service>>::Future: Send + 'static, - <>::Service as Service>>::Error: Into + Send, + L::Service: + Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: + Into + Send, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, { @@ -739,14 +911,15 @@ impl fmt::Debug for Server { } } +#[derive(Clone)] struct Svc { inner: S, trace_interceptor: Option, } -impl Service> for Svc +impl Service> for Svc where - S: Service, Response = Response>, + S: Service, Response = Response>, S::Error: Into, ResBody: http_body::Body + Send + 'static, ResBody::Error: Into, @@ -759,7 +932,7 @@ where self.inner.poll_ready(cx).map_err(Into::into) } - fn call(&mut self, mut req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { let span = if let Some(trace_interceptor) = &self.trace_interceptor { let (parts, body) = req.into_parts(); let bodyless_request = Request::from_parts(parts, ()); @@ -802,7 +975,7 @@ where let _guard = this.span.enter(); let response: Response = ready!(this.inner.poll(cx)).map_err(Into::into)?; - let response = response.map(|body| body.map_err(Into::into).boxed_unsync()); + let response = response.map(|body| boxed(body.map_err(Into::into))); Poll::Ready(Ok(response)) } } @@ -813,6 +986,7 @@ impl fmt::Debug for Svc { } } +#[derive(Clone)] struct MakeSvc { concurrency_limit: Option, timeout: Option, @@ -824,7 +998,7 @@ struct MakeSvc { impl Service<&ServerIo> for MakeSvc where IO: Connected, - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, ResBody: http_body::Body + Send + 'static, @@ -853,8 +1027,8 @@ where .service(svc); let svc = ServiceBuilder::new() - .layer(BoxService::layer()) - .map_request(move |mut request: Request| { + .layer(BoxCloneService::layer()) + .map_request(move |mut request: Request| { match &conn_info { tower::util::Either::A(inner) => { request.extensions_mut().insert(inner.clone()); @@ -885,3 +1059,29 @@ where future::ready(Ok(svc)) } } + +// From `futures-util` crate, borrowed since this is the only dependency tonic requires. +// LICENSE: MIT or Apache-2.0 +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. +#[pin_project] +struct Fuse { + #[pin] + inner: Option, +} + +impl Future for Fuse +where + F: Future, +{ + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.as_mut().project().inner.as_pin_mut() { + Some(fut) => fut.poll(cx).map(|output| { + self.project().inner.set(None); + output + }), + None => Poll::Pending, + } + } +} diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs index fdb14a66a..60b0d9a7b 100644 --- a/tonic/src/transport/server/recover_error.rs +++ b/tonic/src/transport/server/recover_error.rs @@ -1,5 +1,6 @@ use crate::Status; use http::Response; +use http_body::Frame; use pin_project::pin_project; use std::{ future::Future, @@ -98,26 +99,16 @@ where type Data = B::Data; type Error = B::Error; - fn poll_data( + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll, Self::Error>>> { match self.project().inner.as_pin_mut() { - Some(b) => b.poll_data(cx), + Some(b) => b.poll_frame(cx), None => Poll::Ready(None), } } - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>> { - match self.project().inner.as_pin_mut() { - Some(b) => b.poll_trailers(cx), - None => Poll::Ready(Ok(None)), - } - } - fn is_end_stream(&self) -> bool { match &self.inner { Some(b) => b.is_end_stream(), diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index 46a88dda5..a31c9868b 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -1,17 +1,16 @@ +use super::SharedExec; use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent}; use crate::{ - body::BoxBody, + body::{boxed, BoxBody}, transport::{BoxFuture, Endpoint}, }; use http::Uri; -use hyper::client::conn::Builder; -use hyper::client::connect::Connection as HyperConnection; -use hyper::client::service::Connect as HyperConnect; +use hyper::rt; +use hyper::{client::conn::http2::Builder, rt::Executor}; use std::{ fmt, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tower::load::Load; use tower::{ layer::Layer, @@ -21,8 +20,8 @@ use tower::{ }; use tower_service::Service; -pub(crate) type Request = http::Request; -pub(crate) type Response = http::Response; +pub(crate) type Response = http::Response; +pub(crate) type Request = http::Request; pub(crate) struct Connection { inner: BoxService, @@ -34,26 +33,24 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { - let mut settings = Builder::new() - .http2_initial_stream_window_size(endpoint.init_stream_window_size) - .http2_initial_connection_window_size(endpoint.init_connection_window_size) - .http2_only(true) - .http2_keep_alive_interval(endpoint.http2_keep_alive_interval) - .executor(endpoint.executor.clone()) + let mut settings: Builder = Builder::new(endpoint.executor.clone()) + .initial_stream_window_size(endpoint.init_stream_window_size) + .initial_connection_window_size(endpoint.init_connection_window_size) + .keep_alive_interval(endpoint.http2_keep_alive_interval) .clone(); if let Some(val) = endpoint.http2_keep_alive_timeout { - settings.http2_keep_alive_timeout(val); + settings.keep_alive_timeout(val); } if let Some(val) = endpoint.http2_keep_alive_while_idle { - settings.http2_keep_alive_while_idle(val); + settings.keep_alive_while_idle(val); } if let Some(val) = endpoint.http2_adaptive_window { - settings.http2_adaptive_window(val); + settings.adaptive_window(val); } let stack = ServiceBuilder::new() @@ -68,13 +65,13 @@ impl Connection { .option_layer(endpoint.rate_limit.map(|(l, d)| RateLimitLayer::new(l, d))) .into_inner(); - let connector = HyperConnect::new(connector, settings); - let conn = Reconnect::new(connector, endpoint.uri.clone(), is_lazy); + let make_service = + MakeSendRequestService::new(connector, endpoint.executor.clone(), settings); - let inner = stack.layer(conn); + let conn = Reconnect::new(make_service, endpoint.uri.clone(), is_lazy); Self { - inner: BoxService::new(inner), + inner: BoxService::new(stack.layer(conn)), } } @@ -83,7 +80,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, false).ready_oneshot().await } @@ -93,7 +90,7 @@ impl Connection { C: Service + Send + 'static, C::Error: Into + Send, C::Future: Unpin + Send, - C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, { Self::new(connector, endpoint, true) } @@ -126,3 +123,87 @@ impl fmt::Debug for Connection { f.debug_struct("Connection").finish() } } + +struct SendRequest { + inner: hyper::client::conn::http2::SendRequest, +} + +impl From> for SendRequest { + fn from(inner: hyper::client::conn::http2::SendRequest) -> Self { + Self { inner } + } +} + +impl tower::Service> for SendRequest { + type Response = http::Response; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.inner.send_request(req); + + Box::pin(async move { + fut.await + .map_err(Into::into) + .map(|res| res.map(|body| boxed(body))) + }) + } +} + +struct MakeSendRequestService { + connector: C, + executor: super::SharedExec, + settings: Builder, +} + +impl MakeSendRequestService { + fn new(connector: C, executor: SharedExec, settings: Builder) -> Self { + Self { + connector, + executor, + settings, + } + } +} + +impl tower::Service for MakeSendRequestService +where + C: Service + Send + 'static, + C::Error: Into + Send, + C::Future: Unpin + Send, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, +{ + type Response = SendRequest; + type Error = crate::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.connector.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, req: Uri) -> Self::Future { + let fut = self.connector.call(req); + let builder = self.settings.clone(); + let executor = self.executor.clone(); + + Box::pin(async move { + let io = fut.await.map_err(Into::into)?; + let (send_request, conn) = builder.handshake(io).await?; + + Executor::>::execute( + &executor, + Box::pin(async move { + if let Err(e) = conn.await { + tracing::debug!("connection task error: {:?}", e); + } + }) as _, + ); + + Ok(SendRequest::from(send_request)) + }) + } +} diff --git a/tonic/src/transport/service/connector.rs b/tonic/src/transport/service/connector.rs index 12336813a..978441d75 100644 --- a/tonic/src/transport/service/connector.rs +++ b/tonic/src/transport/service/connector.rs @@ -3,12 +3,32 @@ use super::io::BoxedIo; #[cfg(feature = "tls")] use super::tls::TlsConnector; use http::Uri; -#[cfg(feature = "tls")] use std::fmt; use std::task::{Context, Poll}; -use tower::make::MakeConnection; + +use hyper::rt; + +#[cfg(feature = "tls")] +use hyper_util::rt::TokioIo; use tower_service::Service; +/// Wrapper type to indicate that an error occurs during the connection +/// process, so that the appropriate gRPC Status can be inferred. +#[derive(Debug)] +pub(crate) struct ConnectError(pub(crate) crate::Error); + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl std::error::Error for ConnectError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + Some(self.0.as_ref()) + } +} + pub(crate) struct Connector { inner: C, #[cfg(feature = "tls")] @@ -51,17 +71,19 @@ impl Connector { impl Service for Connector where - C: MakeConnection, - C::Connection: Unpin + Send + 'static, + C: Service, + C::Response: rt::Read + rt::Write + Unpin + Send + 'static, C::Future: Send + 'static, crate::Error: From + Send + 'static, { type Response = BoxedIo; - type Error = crate::Error; + type Error = ConnectError; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into) + self.inner + .poll_ready(cx) + .map_err(|err| ConnectError(From::from(err))) } fn call(&mut self, uri: Uri) -> Self::Future { @@ -73,26 +95,30 @@ where #[cfg(feature = "tls")] let is_https = uri.scheme_str() == Some("https"); - let connect = self.inner.make_connection(uri); + let connect = self.inner.call(uri); Box::pin(async move { - let io = connect.await?; - - #[cfg(feature = "tls")] - { - if let Some(tls) = tls { - if is_https { - let conn = tls.connect(io).await?; - return Ok(BoxedIo::new(conn)); - } else { - return Ok(BoxedIo::new(io)); + async { + let io = connect.await?; + + #[cfg(feature = "tls")] + { + if let Some(tls) = tls { + return if is_https { + let io = tls.connect(TokioIo::new(io)).await?; + Ok(io) + } else { + Ok(BoxedIo::new(io)) + }; + } else if is_https { + return Err(HttpsUriWithoutTlsSupport(()).into()); } - } else if is_https { - return Err(HttpsUriWithoutTlsSupport(()).into()); } - } - Ok(BoxedIo::new(io)) + Ok::<_, crate::Error>(BoxedIo::new(io)) + } + .await + .map_err(|err| ConnectError(From::from(err))) }) } } diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/service/discover.rs index 2d23ca74c..b9356110e 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/service/discover.rs @@ -1,6 +1,7 @@ use super::connection::Connection; use crate::transport::Endpoint; +use hyper_util::client::legacy::connect::HttpConnector; use std::{ hash::Hash, pin::Pin, @@ -32,7 +33,7 @@ impl Stream for DynamicServiceStream { Poll::Pending | Poll::Ready(None) => Poll::Pending, Poll::Ready(Some(change)) => match change { Change::Insert(k, endpoint) => { - let mut http = hyper::client::connect::HttpConnector::new(); + let mut http = HttpConnector::new(); http.set_nodelay(endpoint.tcp_nodelay); http.set_keepalive(endpoint.tcp_keepalive); http.set_connect_timeout(endpoint.connect_timeout); diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/service/executor.rs index de3cfbe6e..7b699c307 100644 --- a/tonic/src/transport/service/executor.rs +++ b/tonic/src/transport/service/executor.rs @@ -36,8 +36,11 @@ impl SharedExec { } } -impl Executor> for SharedExec { - fn execute(&self, fut: BoxFuture<'static, ()>) { - self.inner.execute(fut) +impl Executor for SharedExec +where + F: Future + Send + 'static, +{ + fn execute(&self, fut: F) { + self.inner.execute(Box::pin(fut)) } } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 2230b9b2e..cb2296cac 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,5 +1,6 @@ use crate::transport::server::Connected; -use hyper::client::connect::{Connected as HyperConnected, Connection}; +use hyper::rt; +use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection}; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -9,11 +10,11 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_rustls::server::TlsStream; pub(in crate::transport) trait Io: - AsyncRead + AsyncWrite + Send + 'static + rt::Read + rt::Write + Send + 'static { } -impl Io for T where T: AsyncRead + AsyncWrite + Send + 'static {} +impl Io for T where T: rt::Read + rt::Write + Send + 'static {} pub(crate) struct BoxedIo(Pin>); @@ -40,17 +41,17 @@ impl Connected for BoxedIo { #[derive(Copy, Clone)] pub(crate) struct NoneConnectInfo; -impl AsyncRead for BoxedIo { +impl rt::Read for BoxedIo { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, + buf: rt::ReadBufCursor<'_>, ) -> Poll> { Pin::new(&mut self.0).poll_read(cx, buf) } } -impl AsyncWrite for BoxedIo { +impl rt::Write for BoxedIo { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 69d850f10..2b2a84070 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -13,6 +13,7 @@ mod user_agent; pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; +pub(crate) use self::connector::ConnectError; pub(crate) use self::connector::Connector; pub(crate) use self::discover::DynamicServiceStream; pub(crate) use self::executor::SharedExec; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 85636c4d4..c43782ba9 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -1,9 +1,9 @@ use crate::{ body::{boxed, BoxBody}, server::NamedService, + transport::BoxFuture, }; use http::{Request, Response}; -use hyper::Body; use pin_project::pin_project; use std::{ convert::Infallible, @@ -12,7 +12,6 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tower::ServiceExt; use tower_service::Service; /// A [`Service`] router. @@ -31,7 +30,7 @@ impl RoutesBuilder { /// Add a new service. pub fn add_service(&mut self, svc: S) -> &mut Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -53,7 +52,7 @@ impl Routes { /// Create a new routes with `svc` already added to it. pub fn new(svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -68,7 +67,7 @@ impl Routes { /// Add a new service. pub fn add_service(mut self, svc: S) -> Self where - S: Service, Response = Response, Error = Infallible> + S: Service, Response = Response, Error = Infallible> + NamedService + Clone + Send @@ -76,10 +75,10 @@ impl Routes { S::Future: Send + 'static, S::Error: Into + Send, { - let svc = svc.map_response(|res| res.map(axum::body::boxed)); - self.router = self - .router - .route_service(&format!("/{}/*rest", S::NAME), svc); + self.router = self.router.route_service( + &format!("/{}/*rest", S::NAME), + AxumBodyService { service: svc }, + ); self } @@ -103,7 +102,7 @@ async fn unimplemented() -> impl axum::response::IntoResponse { (status, headers) } -impl Service> for Routes { +impl Service> for Routes { type Response = Response; type Error = crate::Error; type Future = RoutesFuture; @@ -113,13 +112,13 @@ impl Service> for Routes { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { RoutesFuture(self.router.call(req)) } } #[pin_project] -pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); +pub struct RoutesFuture(#[pin] axum::routing::future::RouteFuture); impl fmt::Debug for RoutesFuture { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -137,3 +136,33 @@ impl Future for RoutesFuture { } } } + +#[derive(Clone)] +struct AxumBodyService { + service: S, +} + +impl Service> for AxumBodyService +where + S: Service, Response = Response, Error = Infallible> + + Clone + + Send + + 'static, + S::Future: Send + 'static, +{ + type Response = Response; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.service.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let fut = self.service.call(req.map(|body| boxed(body))); + Box::pin(async move { + fut.await + .map(|res| res.map(|body| axum::body::Body::new(body))) + }) + } +} diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 2a6394a4f..2ce9dc5da 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -18,6 +18,7 @@ use crate::transport::{ server::{Connected, TlsStream}, Certificate, Identity, }; +use hyper_util::rt::TokioIo; /// h2 alpn in plain format for rustls. const ALPN_H2: &[u8] = b"h2"; @@ -88,7 +89,7 @@ impl TlsConnector { if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) { return Err(TlsError::H2NotNegotiated.into()); } - Ok(BoxedIo::new(io)) + Ok(BoxedIo::new(TokioIo::new(io))) } }