Skip to content

Commit

Permalink
feat(tls): Remove tls roots implicit configuration (#1731)
Browse files Browse the repository at this point in the history
* feat(tls): Add option to enable tls roots

* feat(tls): Remove tls roots implicit configuration
  • Loading branch information
tottoto authored Jun 21, 2024
1 parent 34b863b commit de73617
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 63 deletions.
5 changes: 2 additions & 3 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ zstd = ["dep:zstd"]
default = ["transport", "codegen", "prost"]
prost = ["dep:prost"]
tls = ["dep:rustls-pemfile", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
tls-roots-common = ["tls", "channel"]
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
tls-roots = ["tls", "channel", "dep:rustls-native-certs"]
tls-webpki-roots = ["tls", "channel", "dep:webpki-roots"]
router = ["dep:axum"]
server = [
"router",
Expand Down
34 changes: 5 additions & 29 deletions tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@ pub struct Endpoint {
pub(crate) rate_limit: Option<(u64, Duration)>,
#[cfg(feature = "tls")]
pub(crate) tls: Option<TlsConnector>,
// Only applies if the tls config is not explicitly set. This allows users
// to connect to a server that doesn't support ALPN while using the
// tls-roots-common feature for setting up TLS.
#[cfg(feature = "tls-roots-common")]
pub(crate) tls_assume_http2: bool,
pub(crate) buffer_size: Option<usize>,
pub(crate) init_stream_window_size: Option<u32>,
pub(crate) init_connection_window_size: Option<u32>,
Expand Down Expand Up @@ -256,18 +251,6 @@ impl Endpoint {
})
}

/// Configures TLS to assume that the server offers HTTP/2 even if it
/// doesn't perform ALPN negotiation. This only applies if a tls_config has
/// not been set.
#[cfg(feature = "tls-roots-common")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls-roots-common")))]
pub fn tls_assume_http2(self, assume_http2: bool) -> Self {
Endpoint {
tls_assume_http2: assume_http2,
..self
}
}

/// Set the value of `TCP_NODELAY` option for accepted connections. Enabled by default.
pub fn tcp_nodelay(self, enabled: bool) -> Self {
Endpoint {
Expand Down Expand Up @@ -320,16 +303,11 @@ impl Endpoint {
}

pub(crate) fn connector<C>(&self, c: C) -> service::Connector<C> {
#[cfg(all(feature = "tls", not(feature = "tls-roots-common")))]
let connector = service::Connector::new(c, self.tls.clone());

#[cfg(all(feature = "tls", feature = "tls-roots-common"))]
let connector = service::Connector::new(c, self.tls.clone(), self.tls_assume_http2);

#[cfg(not(feature = "tls"))]
let connector = service::Connector::new(c);

connector
service::Connector::new(
c,
#[cfg(feature = "tls")]
self.tls.clone(),
)
}

/// Create a channel from this config.
Expand Down Expand Up @@ -435,8 +413,6 @@ impl From<Uri> for Endpoint {
timeout: None,
#[cfg(feature = "tls")]
tls: None,
#[cfg(feature = "tls-roots-common")]
tls_assume_http2: false,
buffer_size: None,
init_stream_window_size: None,
init_connection_window_size: None,
Expand Down
31 changes: 2 additions & 29 deletions tonic/src/transport/channel/service/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,40 +33,16 @@ pub(crate) struct Connector<C> {
inner: C,
#[cfg(feature = "tls")]
tls: Option<TlsConnector>,
// When connecting to a URI with the https scheme, assume that the server
// is capable of speaking HTTP/2 even if it doesn't offer ALPN.
#[cfg(feature = "tls-roots-common")]
assume_http2: bool,
}

impl<C> Connector<C> {
pub(crate) fn new(
inner: C,
#[cfg(feature = "tls")] tls: Option<TlsConnector>,
#[cfg(feature = "tls-roots-common")] assume_http2: bool,
) -> Self {
pub(crate) fn new(inner: C, #[cfg(feature = "tls")] tls: Option<TlsConnector>) -> Self {
Self {
inner,
#[cfg(feature = "tls")]
tls,
#[cfg(feature = "tls-roots-common")]
assume_http2,
}
}

#[cfg(feature = "tls-roots-common")]
fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector> {
if self.tls.is_some() {
return self.tls.clone();
}

let host = match (scheme, host) {
(Some("https"), Some(host)) => host,
_ => return None,
};

TlsConnector::new(Vec::new(), None, host, self.assume_http2).ok()
}
}

impl<C> Service<Uri> for Connector<C>
Expand All @@ -87,12 +63,9 @@ where
}

fn call(&mut self, uri: Uri) -> Self::Future {
#[cfg(all(feature = "tls", not(feature = "tls-roots-common")))]
#[cfg(feature = "tls")]
let tls = self.tls.clone();

#[cfg(feature = "tls-roots-common")]
let tls = self.tls_or_default(uri.scheme_str(), uri.host());

#[cfg(feature = "tls")]
let is_https = uri.scheme_str() == Some("https");
let connect = self.inner.call(uri);
Expand Down
10 changes: 8 additions & 2 deletions tonic/src/transport/channel/service/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,21 @@ impl TlsConnector {
identity: Option<Identity>,
domain: &str,
assume_http2: bool,
#[cfg(feature = "tls-roots")] with_native_roots: bool,
#[cfg(feature = "tls-webpki-roots")] with_webpki_roots: bool,
) -> Result<Self, crate::Error> {
let builder = ClientConfig::builder();
let mut roots = RootCertStore::empty();

#[cfg(feature = "tls-roots")]
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
if with_native_roots {
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
}

#[cfg(feature = "tls-webpki-roots")]
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
if with_webpki_roots {
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
}

for cert in ca_certs {
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
Expand Down
32 changes: 32 additions & 0 deletions tonic/src/transport/channel/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ pub struct ClientTlsConfig {
certs: Vec<Certificate>,
identity: Option<Identity>,
assume_http2: bool,
#[cfg(feature = "tls-roots")]
with_native_roots: bool,
#[cfg(feature = "tls-webpki-roots")]
with_webpki_roots: bool,
}

impl fmt::Debug for ClientTlsConfig {
Expand All @@ -33,6 +37,10 @@ impl ClientTlsConfig {
certs: Vec::new(),
identity: None,
assume_http2: false,
#[cfg(feature = "tls-roots")]
with_native_roots: false,
#[cfg(feature = "tls-webpki-roots")]
with_webpki_roots: false,
}
}

Expand Down Expand Up @@ -75,6 +83,26 @@ impl ClientTlsConfig {
}
}

/// Enables the platform's trusted certs.
#[cfg(feature = "tls-roots")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls-roots")))]
pub fn with_native_roots(self) -> Self {
ClientTlsConfig {
with_native_roots: true,
..self
}
}

/// Enables the webpki roots.
#[cfg(feature = "tls-webpki-roots")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls-webpki-roots")))]
pub fn with_webpki_roots(self) -> Self {
ClientTlsConfig {
with_webpki_roots: true,
..self
}
}

pub(crate) fn tls_connector(&self, uri: Uri) -> Result<TlsConnector, crate::Error> {
let domain = match &self.domain {
Some(domain) => domain,
Expand All @@ -85,6 +113,10 @@ impl ClientTlsConfig {
self.identity.clone(),
domain,
self.assume_http2,
#[cfg(feature = "tls-roots")]
self.with_native_roots,
#[cfg(feature = "tls-webpki-roots")]
self.with_webpki_roots,
)
}
}

0 comments on commit de73617

Please sign in to comment.