From d258fb8dee7190d806c8826a504440182fdb4330 Mon Sep 17 00:00:00 2001 From: tottoto Date: Fri, 9 Feb 2024 21:10:59 +0900 Subject: [PATCH] feat(transport): Make transport server and channel independent --- .github/workflows/CI.yml | 5 +- tonic/Cargo.toml | 19 ++++-- tonic/src/request.rs | 20 +++--- tonic/src/status.rs | 18 ++--- tonic/src/transport/error.rs | 6 +- tonic/src/transport/mod.rs | 9 ++- tonic/src/transport/server/incoming.rs | 2 +- tonic/src/transport/server/mod.rs | 9 +-- .../src/transport/{ => server}/service/io.rs | 3 +- tonic/src/transport/server/service/mod.rs | 8 +++ .../transport/{ => server}/service/router.rs | 0 tonic/src/transport/server/service/tls.rs | 65 +++++++++++++++++++ tonic/src/transport/server/tls.rs | 7 +- tonic/src/transport/service/mod.rs | 10 +-- tonic/src/transport/service/tls.rs | 65 +------------------ 15 files changed, 136 insertions(+), 110 deletions(-) rename tonic/src/transport/{ => server}/service/io.rs (98%) create mode 100644 tonic/src/transport/server/service/mod.rs rename tonic/src/transport/{ => server}/service/router.rs (100%) create mode 100644 tonic/src/transport/server/service/tls.rs diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 06bf43629..fc07a7b67 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -60,7 +60,10 @@ jobs: with: tool: protoc@${{ env.PROTOC_VERSION }} - uses: Swatinem/rust-cache@v2 - - run: cargo hack udeps --workspace --each-feature ${{ matrix.option }} + - run: cargo hack udeps --workspace --exclude-features tls --each-feature ${{ matrix.option }} + - run: cargo udeps --package tonic --features tls,transport + - run: cargo udeps --package tonic --features tls,server + - run: cargo udeps --package tonic --features tls,channel check: runs-on: ${{ matrix.os }} diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index a9f452300..d5f8eb556 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -26,24 +26,29 @@ version = "0.11.0" codegen = ["dep:async-trait"] gzip = ["dep:flate2"] zstd = ["dep:zstd"] -default = ["channel", "codegen", "prost"] +default = ["channel", "codegen", "prost", "server"] prost = ["dep:prost"] -tls = ["dep:rustls-pki-types", "dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"] +tls = ["dep:rustls-pki-types", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"] tls-roots = ["tls-roots-common", "dep:rustls-native-certs"] tls-roots-common = ["tls", "channel"] tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"] transport = [ + "dep:tower", "tower?/util", "tower?/limit", + "dep:tokio", "tokio?/time", + "dep:hyper", +] +server = [ + "transport", "dep:async-stream", + "tokio?/net", "dep:axum", "dep:h2", - "dep:hyper", "hyper?/server", - "dep:tokio", "tokio?/net", "tokio?/time", - "dep:tower", "tower?/util", "tower?/limit", + "hyper?/server", ] channel = [ "transport", - "dep:hyper", "hyper?/client", - "dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make", + "hyper?/client", + "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make", "dep:hyper-timeout", ] diff --git a/tonic/src/request.rs b/tonic/src/request.rs index 76bf4e9eb..abe19d322 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,12 +1,12 @@ use crate::metadata::{MetadataMap, MetadataValue}; -#[cfg(feature = "transport")] +#[cfg(feature = "server")] use crate::transport::server::TcpConnectInfo; -#[cfg(feature = "tls")] +#[cfg(all(feature = "server", feature = "tls"))] use crate::transport::{server::TlsConnectInfo, Certificate}; use crate::Extensions; -#[cfg(feature = "transport")] +#[cfg(feature = "server")] use std::net::SocketAddr; -#[cfg(feature = "tls")] +#[cfg(all(feature = "server", feature = "tls"))] use std::sync::Arc; use std::time::Duration; use tokio_stream::Stream; @@ -209,8 +209,8 @@ impl Request { /// This will return `None` if the `IO` type used /// does not implement `Connected` or when using a unix domain socket. /// This currently only works on the server side. - #[cfg(feature = "transport")] - #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] + #[cfg(feature = "server")] + #[cfg_attr(docsrs, doc(cfg(feature = "server")))] pub fn local_addr(&self) -> Option { let addr = self .extensions() @@ -232,8 +232,8 @@ impl Request { /// This will return `None` if the `IO` type used /// does not implement `Connected` or when using a unix domain socket. /// This currently only works on the server side. - #[cfg(feature = "transport")] - #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] + #[cfg(feature = "server")] + #[cfg_attr(docsrs, doc(cfg(feature = "server")))] pub fn remote_addr(&self) -> Option { let addr = self .extensions() @@ -256,8 +256,8 @@ impl Request { /// and is mostly used for mTLS. This currently only returns /// `Some` on the server side of the `transport` server with /// TLS enabled connections. - #[cfg(feature = "tls")] - #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] + #[cfg(all(feature = "server", feature = "tls"))] + #[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "tls"))))] pub fn peer_certs(&self) -> Option>> { self.extensions() .get::>() diff --git a/tonic/src/status.rs b/tonic/src/status.rs index da8b792e5..40552201c 100644 --- a/tonic/src/status.rs +++ b/tonic/src/status.rs @@ -342,7 +342,7 @@ impl Status { Err(err) => err, }; - #[cfg(feature = "transport")] + #[cfg(feature = "server")] let err = match err.downcast::() { Ok(h2) => { return Ok(Status::from_h2_error(h2)); @@ -359,7 +359,7 @@ impl Status { } // FIXME: bubble this into `transport` and expose generic http2 reasons. - #[cfg(feature = "transport")] + #[cfg(feature = "server")] fn from_h2_error(err: Box) -> Status { let code = Self::code_from_h2(&err); @@ -368,7 +368,7 @@ impl Status { status } - #[cfg(feature = "transport")] + #[cfg(feature = "server")] fn code_from_h2(err: &h2::Error) -> Code { // See https://github.com/grpc/grpc/blob/3977c30/doc/PROTOCOL-HTTP2.md#errors match err.reason() { @@ -388,7 +388,7 @@ impl Status { } } - #[cfg(feature = "transport")] + #[cfg(feature = "server")] fn to_h2_error(&self) -> h2::Error { // conservatively transform to h2 error codes... let reason = match self.code { @@ -404,7 +404,7 @@ impl Status { /// /// Returns Some if there's a way to handle the error, or None if the information from this /// hyper error, but perhaps not its source, should be ignored. - #[cfg(feature = "transport")] + #[cfg(feature = "server")] fn from_hyper_error(err: &hyper::Error) -> Option { // is_timeout results from hyper's keep-alive logic // (https://docs.rs/hyper/0.14.11/src/hyper/error.rs.html#192-194). Per the grpc spec @@ -614,12 +614,12 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option { }); } - #[cfg(feature = "transport")] + #[cfg(any(feature = "server", feature = "channel"))] if let Some(timeout) = err.downcast_ref::() { return Some(Status::cancelled(timeout.to_string())); } - #[cfg(feature = "transport")] + #[cfg(feature = "server")] if let Some(hyper) = err .downcast_ref::() .and_then(Status::from_hyper_error) @@ -666,14 +666,14 @@ fn invalid_header_value_byte(err: Error) -> Status { ) } -#[cfg(feature = "transport")] +#[cfg(feature = "server")] impl From for Status { fn from(err: h2::Error) -> Self { Status::from_h2_error(Box::new(err)) } } -#[cfg(feature = "transport")] +#[cfg(feature = "server")] impl From for h2::Error { fn from(status: Status) -> Self { status.to_h2_error() diff --git a/tonic/src/transport/error.rs b/tonic/src/transport/error.rs index 92a910498..668997a1c 100644 --- a/tonic/src/transport/error.rs +++ b/tonic/src/transport/error.rs @@ -14,6 +14,7 @@ struct ErrorImpl { #[derive(Debug)] pub(crate) enum Kind { + #[allow(unused)] Transport, #[cfg(feature = "channel")] InvalidUri, @@ -22,17 +23,20 @@ pub(crate) enum Kind { } impl Error { - pub(crate) fn new(kind: Kind) -> Self { + #[cfg(any(feature = "server", feature = "channel"))] + fn new(kind: Kind) -> Self { Self { inner: ErrorImpl { kind, source: None }, } } + #[cfg(any(feature = "server", feature = "channel"))] pub(crate) fn with(mut self, source: impl Into) -> Self { self.inner.source = Some(source.into()); self } + #[cfg(any(feature = "server", feature = "channel"))] pub(crate) fn from_source(source: impl Into) -> Self { Error::new(Kind::Transport).with(source) } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 5a41e9524..a84f03e1e 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -89,6 +89,7 @@ #[cfg(feature = "channel")] pub mod channel; +#[cfg(feature = "server")] pub mod server; mod error; @@ -102,12 +103,16 @@ mod tls; pub use self::channel::{Channel, Endpoint}; pub use self::error::Error; #[doc(inline)] +#[cfg(feature = "server")] pub use self::server::Server; #[doc(inline)] +#[cfg(any(feature = "server", feature = "channel"))] pub use self::service::grpc_timeout::TimeoutExpired; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::tls::Certificate; +#[cfg(feature = "server")] +#[cfg_attr(docsrs, doc(cfg(feature = "server")))] pub use axum::{body::BoxBody as AxumBoxBody, Router as AxumRouter}; pub use hyper::{Body, Uri}; @@ -117,8 +122,8 @@ pub(crate) use self::channel::service::executor::Executor; #[cfg(all(feature = "channel", feature = "tls"))] #[cfg_attr(docsrs, doc(cfg(all(feature = "channel", feature = "tls"))))] pub use self::channel::ClientTlsConfig; -#[cfg(feature = "tls")] -#[cfg_attr(docsrs, doc(cfg(feature = "tls")))] +#[cfg(all(feature = "server", feature = "tls"))] +#[cfg_attr(docsrs, doc(cfg(all(feature = "server", feature = "tls"))))] pub use self::server::ServerTlsConfig; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index bc1bb7650..eb431affb 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -1,5 +1,5 @@ +use super::service::ServerIo; use super::{Connected, Server}; -use crate::transport::service::ServerIo; use hyper::server::{ accept::Accept, conn::{AddrIncoming, AddrStream}, diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 7f2ffde2b..e1798d117 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -3,14 +3,14 @@ mod conn; mod incoming; mod recover_error; +mod service; #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; #[cfg(unix)] mod unix; -pub use super::service::Routes; -pub use super::service::RoutesBuilder; +pub use self::service::{Routes, RoutesBuilder}; pub use conn::{Connected, TcpConnectInfo}; #[cfg(feature = "tls")] @@ -20,7 +20,7 @@ pub use tls::ServerTlsConfig; pub use conn::TlsConnectInfo; #[cfg(feature = "tls")] -use super::service::TlsAcceptor; +use self::service::tls::TlsAcceptor; #[cfg(unix)] pub use unix::UdsConnectInfo; @@ -34,7 +34,8 @@ pub(crate) use tokio_rustls::server::TlsStream; use crate::transport::Error; use self::recover_error::RecoverError; -use super::service::{GrpcTimeout, ServerIo}; +use self::service::ServerIo; +use super::service::GrpcTimeout; use crate::body::BoxBody; use crate::server::NamedService; use bytes::Bytes; diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/server/service/io.rs similarity index 98% rename from tonic/src/transport/service/io.rs rename to tonic/src/transport/server/service/io.rs index 7821f691c..ba8830d40 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/server/service/io.rs @@ -1,4 +1,3 @@ -use crate::transport::server::Connected; use std::io; use std::io::IoSlice; use std::pin::Pin; @@ -7,6 +6,8 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; #[cfg(feature = "tls")] use tokio_rustls::server::TlsStream; +use super::super::Connected; + pub(crate) enum ServerIo { Io(IO), #[cfg(feature = "tls")] diff --git a/tonic/src/transport/server/service/mod.rs b/tonic/src/transport/server/service/mod.rs new file mode 100644 index 000000000..79af829be --- /dev/null +++ b/tonic/src/transport/server/service/mod.rs @@ -0,0 +1,8 @@ +pub(crate) mod io; +pub(crate) use self::io::ServerIo; + +mod router; +pub use self::router::{Routes, RoutesBuilder}; + +#[cfg(feature = "tls")] +pub(crate) mod tls; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/server/service/router.rs similarity index 100% rename from tonic/src/transport/service/router.rs rename to tonic/src/transport/server/service/router.rs diff --git a/tonic/src/transport/server/service/tls.rs b/tonic/src/transport/server/service/tls.rs new file mode 100644 index 000000000..7b0025f54 --- /dev/null +++ b/tonic/src/transport/server/service/tls.rs @@ -0,0 +1,65 @@ +use std::io::Cursor; +use std::{fmt, sync::Arc}; + +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_rustls::{ + rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}, + TlsAcceptor as RustlsAcceptor, +}; + +use crate::transport::server::Connected; +use crate::transport::server::TlsStream; +use crate::transport::service::tls::{add_certs_from_pem, load_identity, ALPN_H2}; +use crate::transport::tls::{Certificate, Identity}; + +#[derive(Clone)] +pub(crate) struct TlsAcceptor { + inner: Arc, +} + +impl TlsAcceptor { + pub(crate) fn new( + identity: Identity, + client_ca_root: Option, + client_auth_optional: bool, + ) -> Result { + let builder = ServerConfig::builder(); + + let builder = match client_ca_root { + None => builder.with_no_client_auth(), + Some(cert) => { + let mut roots = RootCertStore::empty(); + add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; + let verifier = if client_auth_optional { + WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated() + } else { + WebPkiClientVerifier::builder(roots.into()) + } + .build()?; + builder.with_client_cert_verifier(verifier) + } + }; + + let (cert, key) = load_identity(identity)?; + let mut config = builder.with_single_cert(cert, key)?; + + config.alpn_protocols.push(ALPN_H2.into()); + Ok(Self { + inner: Arc::new(config), + }) + } + + pub(crate) async fn accept(&self, io: IO) -> Result, crate::Error> + where + IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + { + let acceptor = RustlsAcceptor::from(self.inner.clone()); + acceptor.accept(io).await.map_err(Into::into) + } +} + +impl fmt::Debug for TlsAcceptor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TlsAcceptor").finish() + } +} diff --git a/tonic/src/transport/server/tls.rs b/tonic/src/transport/server/tls.rs index d320b5024..91d24a201 100644 --- a/tonic/src/transport/server/tls.rs +++ b/tonic/src/transport/server/tls.rs @@ -1,9 +1,8 @@ -use crate::transport::{ - service::TlsAcceptor, - tls::{Certificate, Identity}, -}; use std::fmt; +use crate::transport::server::service::tls::TlsAcceptor; +use crate::transport::tls::{Certificate, Identity}; + /// Configures TLS settings for servers. #[derive(Clone, Default)] pub struct ServerTlsConfig { diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 98129790f..d5002dc34 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -1,13 +1,7 @@ +#[cfg(any(feature = "server", feature = "channel"))] pub(crate) mod grpc_timeout; -pub(crate) mod io; -mod router; #[cfg(feature = "tls")] pub(super) mod tls; +#[cfg(any(feature = "server", feature = "channel"))] pub(crate) use self::grpc_timeout::GrpcTimeout; -pub(crate) use self::io::ServerIo; -#[cfg(feature = "tls")] -pub(crate) use self::tls::TlsAcceptor; - -pub use self::router::Routes; -pub use self::router::RoutesBuilder; diff --git a/tonic/src/transport/service/tls.rs b/tonic/src/transport/service/tls.rs index 62f763e2e..8d5f24e89 100644 --- a/tonic/src/transport/service/tls.rs +++ b/tonic/src/transport/service/tls.rs @@ -1,17 +1,10 @@ +use std::fmt; use std::io::Cursor; -use std::{fmt, sync::Arc}; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_rustls::{ - rustls::{server::WebPkiClientVerifier, RootCertStore, ServerConfig}, - TlsAcceptor as RustlsAcceptor, -}; +use tokio_rustls::rustls::RootCertStore; -use crate::transport::{ - server::{Connected, TlsStream}, - Certificate, Identity, -}; +use crate::transport::tls::Identity; /// h2 alpn in plain format for rustls. pub(crate) const ALPN_H2: &[u8] = b"h2"; @@ -22,58 +15,6 @@ enum TlsError { PrivateKeyParseError, } -#[derive(Clone)] -pub(crate) struct TlsAcceptor { - inner: Arc, -} - -impl TlsAcceptor { - pub(crate) fn new( - identity: Identity, - client_ca_root: Option, - client_auth_optional: bool, - ) -> Result { - let builder = ServerConfig::builder(); - - let builder = match client_ca_root { - None => builder.with_no_client_auth(), - Some(cert) => { - let mut roots = RootCertStore::empty(); - add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?; - let verifier = if client_auth_optional { - WebPkiClientVerifier::builder(roots.into()).allow_unauthenticated() - } else { - WebPkiClientVerifier::builder(roots.into()) - } - .build()?; - builder.with_client_cert_verifier(verifier) - } - }; - - let (cert, key) = load_identity(identity)?; - let mut config = builder.with_single_cert(cert, key)?; - - config.alpn_protocols.push(ALPN_H2.into()); - Ok(Self { - inner: Arc::new(config), - }) - } - - pub(crate) async fn accept(&self, io: IO) -> Result, crate::Error> - where - IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, - { - let acceptor = RustlsAcceptor::from(self.inner.clone()); - acceptor.accept(io).await.map_err(Into::into) - } -} - -impl fmt::Debug for TlsAcceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TlsAcceptor").finish() - } -} - impl fmt::Display for TlsError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self {