diff --git a/tokio-rustls/src/common/test_stream.rs b/tokio-rustls/src/common/test_stream.rs index 89a5686d..3b3966d4 100644 --- a/tokio-rustls/src/common/test_stream.rs +++ b/tokio-rustls/src/common/test_stream.rs @@ -1,11 +1,9 @@ use super::Stream; use futures_util::future::poll_fn; use futures_util::task::noop_waker_ref; -use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection}; -use rustls_pemfile::{certs, rsa_private_keys}; -use std::io::{self, BufReader, Cursor, Read, Write}; +use rustls::{ClientConnection, Connection, ServerConnection}; +use std::io::{self, Cursor, Read, Write}; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; @@ -261,45 +259,11 @@ async fn stream_eof() -> io::Result<()> { fn make_pair() -> (ServerConnection, ClientConnection) { use std::convert::TryFrom; - const CERT: &str = include_str!("../../tests/end.cert"); - const CHAIN: &str = include_str!("../../tests/end.chain"); - const RSA: &str = include_str!("../../tests/end.rsa"); - - let cert = certs(&mut BufReader::new(Cursor::new(CERT))) - .unwrap() - .drain(..) - .map(rustls::Certificate) - .collect(); - let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); - let mut keys = keys.drain(..).map(rustls::PrivateKey); - let sconfig = rustls::ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_single_cert(cert, keys.next().unwrap()) - .unwrap(); - let server = ServerConnection::new(Arc::new(sconfig)).unwrap(); + let (sconfig, cconfig) = utils::make_configs(); + let server = ServerConnection::new(sconfig).unwrap(); let domain = rustls::ServerName::try_from("localhost").unwrap(); - let mut client_root_cert_store = RootCertStore::empty(); - let mut chain = BufReader::new(Cursor::new(CHAIN)); - let certs = certs(&mut chain).unwrap(); - let trust_anchors = certs - .iter() - .map(|cert| { - let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }) - .collect::>(); - client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter()); - let cconfig = rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(client_root_cert_store) - .with_no_client_auth(); - let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap(); + let client = ClientConnection::new(cconfig, domain).unwrap(); (server, client) } @@ -322,3 +286,6 @@ fn do_handshake( Poll::Ready(Ok(())) } + +// Share `utils` module with integration tests +include!("../../tests/utils.rs"); diff --git a/tokio-rustls/src/lib.rs b/tokio-rustls/src/lib.rs index d849cb20..9c2a8d41 100644 --- a/tokio-rustls/src/lib.rs +++ b/tokio-rustls/src/lib.rs @@ -188,6 +188,124 @@ impl TlsAcceptor { } } +pub struct LazyConfigAcceptor { + acceptor: rustls::server::Acceptor, + buf: Vec, + used: usize, + io: Option, +} + +impl LazyConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { + Self { + acceptor, + buf: vec![0; 512], + used: 0, + io: Some(io), + } + } +} + +impl Future for LazyConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result, io::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + loop { + let io = match this.io.as_mut() { + Some(io) => io, + None => { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::Other, + "acceptor cannot be polled after acceptance", + ))) + } + }; + + let mut buf = ReadBuf::new(&mut this.buf); + buf.advance(this.used); + if buf.remaining() > 0 { + if let Err(err) = ready!(Pin::new(io).poll_read(cx, &mut buf)) { + return Poll::Ready(Err(err)); + } + } + + let read = match this.acceptor.read_tls(&mut buf.filled()) { + Ok(read) => read, + Err(err) => return Poll::Ready(Err(err)), + }; + + let received = buf.filled().len(); + if read < received { + this.buf.copy_within(read.., 0); + this.used = received - read; + } else { + this.used = 0; + } + + match this.acceptor.accept() { + Ok(Some(accepted)) => { + let io = this.io.take().unwrap(); + return Poll::Ready(Ok(StartHandshake { accepted, io })); + } + Ok(None) => continue, + Err(err) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))) + } + } + } + } +} + +pub struct StartHandshake { + accepted: rustls::server::Accepted, + io: IO, +} + +impl StartHandshake +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { + self.accepted.client_hello() + } + + pub fn into_stream(self, config: Arc) -> Accept { + self.into_stream_with(config, |_| ()) + } + + pub fn into_stream_with(self, config: Arc, f: F) -> Accept + where + F: FnOnce(&mut ServerConnection), + { + let mut conn = match self.accepted.into_connection(config) { + Ok(conn) => conn, + Err(error) => { + return Accept(MidHandshake::Error { + io: self.io, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; + f(&mut conn); + + Accept(MidHandshake::Handshaking(server::TlsStream { + session: conn, + io: self.io, + state: TlsState::Stream, + })) + } +} + /// Future returned from `TlsConnector::connect` which will resolve /// once the connection handshake has finished. pub struct Connect(MidHandshake>); diff --git a/tokio-rustls/tests/test.rs b/tokio-rustls/tests/test.rs index f0f5f56b..29fa6038 100644 --- a/tokio-rustls/tests/test.rs +++ b/tokio-rustls/tests/test.rs @@ -11,7 +11,7 @@ use std::{io, thread}; use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::runtime; -use tokio_rustls::{TlsAcceptor, TlsConnector}; +use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector}; const CERT: &str = include_str!("end.cert"); const CHAIN: &[u8] = include_bytes!("end.chain"); @@ -164,3 +164,43 @@ async fn fail() -> io::Result<()> { Ok(()) } + +#[tokio::test] +async fn test_lazy_config_acceptor() -> io::Result<()> { + let (sconfig, cconfig) = utils::make_configs(); + use std::convert::TryFrom; + + let (cstream, sstream) = tokio::io::duplex(1200); + let domain = rustls::ServerName::try_from("localhost").unwrap(); + tokio::spawn(async move { + let connector = crate::TlsConnector::from(cconfig); + let mut client = connector.connect(domain, cstream).await.unwrap(); + client.write_all(b"hello, world!").await.unwrap(); + + let mut buf = Vec::new(); + client.read_to_end(&mut buf).await.unwrap(); + }); + + let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::new().unwrap(), sstream); + let start = acceptor.await.unwrap(); + let ch = start.client_hello(); + + assert_eq!(ch.server_name(), Some("localhost")); + assert_eq!( + ch.alpn() + .map(|protos| protos.collect::>()) + .unwrap_or(Vec::new()), + Vec::<&[u8]>::new() + ); + + let mut stream = start.into_stream(sconfig).await.unwrap(); + let mut buf = [0; 13]; + stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf[..], b"hello, world!"); + + stream.write_all(b"bye").await.unwrap(); + Ok(()) +} + +// Include `utils` module +include!("utils.rs"); diff --git a/tokio-rustls/tests/utils.rs b/tokio-rustls/tests/utils.rs new file mode 100644 index 00000000..c0f7ed0b --- /dev/null +++ b/tokio-rustls/tests/utils.rs @@ -0,0 +1,49 @@ +mod utils { + use std::io::{BufReader, Cursor}; + use std::sync::Arc; + + use rustls::{ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerConfig}; + use rustls_pemfile::{certs, rsa_private_keys}; + + #[allow(dead_code)] + pub fn make_configs() -> (Arc, Arc) { + const CERT: &str = include_str!("end.cert"); + const CHAIN: &str = include_str!("end.chain"); + const RSA: &str = include_str!("end.rsa"); + + let cert = certs(&mut BufReader::new(Cursor::new(CERT))) + .unwrap() + .drain(..) + .map(rustls::Certificate) + .collect(); + let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap(); + let mut keys = keys.drain(..).map(PrivateKey); + let sconfig = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_single_cert(cert, keys.next().unwrap()) + .unwrap(); + + let mut client_root_cert_store = RootCertStore::empty(); + let mut chain = BufReader::new(Cursor::new(CHAIN)); + let certs = certs(&mut chain).unwrap(); + let trust_anchors = certs + .iter() + .map(|cert| { + let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap(); + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + }) + .collect::>(); + client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter()); + let cconfig = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(client_root_cert_store) + .with_no_client_auth(); + + (Arc::new(sconfig), Arc::new(cconfig)) + } +}