Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to rustls 0.20 #244

Merged
merged 2 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,19 @@ version = "0.2.3"

[dependencies.rustls]
optional = true
version = "0.19.0"
version = "0.20.0"

[dependencies.rustls-native-certs]
optional = true
version = "0.5.0"
version = "0.6.0"

[dependencies.webpki]
optional = true
version = "0.21"
version = "0.22"

[dependencies.webpki-roots]
optional = true
version = "0.21"
version = "0.22"

[dev-dependencies]
criterion = "0.3.4"
Expand Down
2 changes: 1 addition & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub fn connect_with_config<Req: IntoClientRequest>(
Mode::Tls => 443,
});
let addrs = (host, port).to_socket_addrs()?;
let mut stream = connect_to_some(addrs.as_slice(), &request.uri())?;
let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
NoDelay::set_nodelay(&mut stream, true)?;

#[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
Expand Down
10 changes: 7 additions & 3 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,13 @@ pub enum TlsError {
/// Rustls error.
#[cfg(feature = "__rustls-tls")]
#[error("rustls error: {0}")]
Rustls(#[from] rustls::TLSError),
Rustls(#[from] rustls::Error),
/// Webpki error.
#[cfg(feature = "__rustls-tls")]
#[error("webpki error: {0}")]
Webpki(#[from] webpki::Error),
/// DNS name resolution error.
#[cfg(feature = "__rustls-tls")]
#[error("Invalid DNS name: {0}")]
Dns(#[from] webpki::InvalidDNSNameError),
#[error("Invalid DNS name")]
InvalidDnsName,
}
6 changes: 3 additions & 3 deletions src/handshake/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {

/// Create a response for the request.
pub fn create_response(request: &Request) -> Result<Response> {
Ok(create_parts(&request)?.body(())?)
Ok(create_parts(request)?.body(())?)
}

/// Create a response for the request with a custom body.
pub fn create_response_with_body<T>(
request: &HttpRequest<T>,
generate_body: impl FnOnce() -> T,
) -> Result<HttpResponse<T>> {
Ok(create_parts(&request)?.body(generate_body())?)
Ok(create_parts(request)?.body(generate_body())?)
}

// Assumes that this is a valid response
Expand Down Expand Up @@ -263,7 +263,7 @@ impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
let resp = self.error_response.as_ref().unwrap();

let mut output = vec![];
write_response(&mut output, &resp)?;
write_response(&mut output, resp)?;

if let Some(body) = resp.body() {
output.extend_from_slice(body.as_bytes());
Expand Down
15 changes: 11 additions & 4 deletions src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
//! `native_tls` or `openssl` will work as long as there is a TLS stream supporting standard
//! `Read + Write` traits.

#[cfg(feature = "__rustls-tls")]
use std::ops::Deref;
use std::{
fmt::{self, Debug},
io::{Read, Result as IoResult, Write},
Expand Down Expand Up @@ -45,7 +47,12 @@ impl<S: Read + Write + NoDelay> NoDelay for TlsStream<S> {
}

#[cfg(feature = "__rustls-tls")]
impl<S: rustls::Session, T: Read + Write + NoDelay> NoDelay for StreamOwned<S, T> {
impl<S, SD, T> NoDelay for StreamOwned<S, T>
where
S: Deref<Target = rustls::ConnectionCommon<SD>>,
SD: rustls::SideData,
T: Read + Write + NoDelay,
{
fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.sock.set_nodelay(nodelay)
}
Expand All @@ -61,7 +68,7 @@ pub enum MaybeTlsStream<S: Read + Write> {
NativeTls(native_tls_crate::TlsStream<S>),
#[cfg(feature = "__rustls-tls")]
/// Encrypted socket stream using `rustls`.
Rustls(rustls::StreamOwned<rustls::ClientSession, S>),
Rustls(rustls::StreamOwned<rustls::ClientConnection, S>),
}

impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
Expand All @@ -73,13 +80,13 @@ impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
#[cfg(feature = "__rustls-tls")]
Self::Rustls(s) => {
struct RustlsStreamDebug<'a, S: Read + Write>(
&'a rustls::StreamOwned<rustls::ClientSession, S>,
&'a rustls::StreamOwned<rustls::ClientConnection, S>,
);

impl<'a, S: Read + Write + Debug> Debug for RustlsStreamDebug<'a, S> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StreamOwned")
.field("sess", &self.0.sess)
.field("conn", &self.0.conn)
.field("sock", &self.0.sock)
.finish()
}
Expand Down
40 changes: 28 additions & 12 deletions src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ mod encryption {

#[cfg(feature = "__rustls-tls")]
pub mod rustls {
use rustls::{ClientConfig, ClientSession, StreamOwned};
use webpki::DNSNameRef;
use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned};

use std::{
convert::TryFrom,
io::{Read, Write},
sync::Arc,
};
Expand All @@ -100,24 +100,40 @@ mod encryption {
Some(config) => config,
None => {
#[allow(unused_mut)]
let mut config = ClientConfig::new();
let mut root_store = RootCertStore::empty();

#[cfg(feature = "rustls-tls-native-roots")]
{
config.root_store = rustls_native_certs::load_native_certs()
.map_err(|(_, err)| err)?;
for cert in rustls_native_certs::load_native_certs()? {
root_store
.add(&rustls::Certificate(cert.0))
.map_err(TlsError::Webpki)?;
}
}
#[cfg(feature = "rustls-tls-webpki-roots")]
{
config
.root_store
.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
root_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
);
}

Arc::new(config)
Arc::new(
ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth(),
)
}
};
let domain = DNSNameRef::try_from_ascii_str(domain).map_err(TlsError::Dns)?;
let client = ClientSession::new(&config, domain);
let domain =
ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?;
let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
let stream = StreamOwned::new(client, socket);

Ok(MaybeTlsStream::Rustls(stream))
Expand Down Expand Up @@ -185,7 +201,7 @@ where
None => Err(Error::Url(UrlError::NoHostName)),
}?;

let mode = uri_mode(&request.uri())?;
let mode = uri_mode(request.uri())?;

let stream = match connector {
Some(conn) => match conn {
Expand Down