Skip to content

Commit

Permalink
Add LazyConfigAcceptor API (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
djc authored Oct 30, 2021
1 parent 48caaf7 commit 3350601
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 42 deletions.
49 changes: 8 additions & 41 deletions tokio-rustls/src/common/test_stream.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use super::Stream;
use futures_util::future::poll_fn;
use futures_util::task::noop_waker_ref;
use rustls::{ClientConnection, Connection, OwnedTrustAnchor, RootCertStore, ServerConnection};
use rustls_pemfile::{certs, rsa_private_keys};
use std::io::{self, BufReader, Cursor, Read, Write};
use rustls::{ClientConnection, Connection, ServerConnection};
use std::io::{self, Cursor, Read, Write};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};

Expand Down Expand Up @@ -261,45 +259,11 @@ async fn stream_eof() -> io::Result<()> {
fn make_pair() -> (ServerConnection, ClientConnection) {
use std::convert::TryFrom;

const CERT: &str = include_str!("../../tests/end.cert");
const CHAIN: &str = include_str!("../../tests/end.chain");
const RSA: &str = include_str!("../../tests/end.rsa");

let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.unwrap()
.drain(..)
.map(rustls::Certificate)
.collect();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut keys = keys.drain(..).map(rustls::PrivateKey);
let sconfig = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, keys.next().unwrap())
.unwrap();
let server = ServerConnection::new(Arc::new(sconfig)).unwrap();
let (sconfig, cconfig) = utils::make_configs();
let server = ServerConnection::new(sconfig).unwrap();

let domain = rustls::ServerName::try_from("localhost").unwrap();
let mut client_root_cert_store = RootCertStore::empty();
let mut chain = BufReader::new(Cursor::new(CHAIN));
let certs = certs(&mut chain).unwrap();
let trust_anchors = certs
.iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
let cconfig = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(client_root_cert_store)
.with_no_client_auth();
let client = ClientConnection::new(Arc::new(cconfig), domain).unwrap();
let client = ClientConnection::new(cconfig, domain).unwrap();

(server, client)
}
Expand All @@ -322,3 +286,6 @@ fn do_handshake(

Poll::Ready(Ok(()))
}

// Share `utils` module with integration tests
include!("../../tests/utils.rs");
118 changes: 118 additions & 0 deletions tokio-rustls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,124 @@ impl TlsAcceptor {
}
}

pub struct LazyConfigAcceptor<IO> {
acceptor: rustls::server::Acceptor,
buf: Vec<u8>,
used: usize,
io: Option<IO>,
}

impl<IO> LazyConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
#[inline]
pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
Self {
acceptor,
buf: vec![0; 512],
used: 0,
io: Some(io),
}
}
}

impl<IO> Future for LazyConfigAcceptor<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<StartHandshake<IO>, io::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
let io = match this.io.as_mut() {
Some(io) => io,
None => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"acceptor cannot be polled after acceptance",
)))
}
};

let mut buf = ReadBuf::new(&mut this.buf);
buf.advance(this.used);
if buf.remaining() > 0 {
if let Err(err) = ready!(Pin::new(io).poll_read(cx, &mut buf)) {
return Poll::Ready(Err(err));
}
}

let read = match this.acceptor.read_tls(&mut buf.filled()) {
Ok(read) => read,
Err(err) => return Poll::Ready(Err(err)),
};

let received = buf.filled().len();
if read < received {
this.buf.copy_within(read.., 0);
this.used = received - read;
} else {
this.used = 0;
}

match this.acceptor.accept() {
Ok(Some(accepted)) => {
let io = this.io.take().unwrap();
return Poll::Ready(Ok(StartHandshake { accepted, io }));
}
Ok(None) => continue,
Err(err) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
}
}
}
}
}

pub struct StartHandshake<IO> {
accepted: rustls::server::Accepted,
io: IO,
}

impl<IO> StartHandshake<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
self.accepted.client_hello()
}

pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
self.into_stream_with(config, |_| ())
}

pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
where
F: FnOnce(&mut ServerConnection),
{
let mut conn = match self.accepted.into_connection(config) {
Ok(conn) => conn,
Err(error) => {
return Accept(MidHandshake::Error {
io: self.io,
// TODO(eliza): should this really return an `io::Error`?
// Probably not...
error: io::Error::new(io::ErrorKind::Other, error),
});
}
};
f(&mut conn);

Accept(MidHandshake::Handshaking(server::TlsStream {
session: conn,
io: self.io,
state: TlsState::Stream,
}))
}
}

/// Future returned from `TlsConnector::connect` which will resolve
/// once the connection handshake has finished.
pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
Expand Down
42 changes: 41 additions & 1 deletion tokio-rustls/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{io, thread};
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::runtime;
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector};

const CERT: &str = include_str!("end.cert");
const CHAIN: &[u8] = include_bytes!("end.chain");
Expand Down Expand Up @@ -164,3 +164,43 @@ async fn fail() -> io::Result<()> {

Ok(())
}

#[tokio::test]
async fn test_lazy_config_acceptor() -> io::Result<()> {
let (sconfig, cconfig) = utils::make_configs();
use std::convert::TryFrom;

let (cstream, sstream) = tokio::io::duplex(1200);
let domain = rustls::ServerName::try_from("localhost").unwrap();
tokio::spawn(async move {
let connector = crate::TlsConnector::from(cconfig);
let mut client = connector.connect(domain, cstream).await.unwrap();
client.write_all(b"hello, world!").await.unwrap();

let mut buf = Vec::new();
client.read_to_end(&mut buf).await.unwrap();
});

let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::new().unwrap(), sstream);
let start = acceptor.await.unwrap();
let ch = start.client_hello();

assert_eq!(ch.server_name(), Some("localhost"));
assert_eq!(
ch.alpn()
.map(|protos| protos.collect::<Vec<_>>())
.unwrap_or(Vec::new()),
Vec::<&[u8]>::new()
);

let mut stream = start.into_stream(sconfig).await.unwrap();
let mut buf = [0; 13];
stream.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf[..], b"hello, world!");

stream.write_all(b"bye").await.unwrap();
Ok(())
}

// Include `utils` module
include!("utils.rs");
49 changes: 49 additions & 0 deletions tokio-rustls/tests/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
mod utils {
use std::io::{BufReader, Cursor};
use std::sync::Arc;

use rustls::{ClientConfig, OwnedTrustAnchor, PrivateKey, RootCertStore, ServerConfig};
use rustls_pemfile::{certs, rsa_private_keys};

#[allow(dead_code)]
pub fn make_configs() -> (Arc<ServerConfig>, Arc<ClientConfig>) {
const CERT: &str = include_str!("end.cert");
const CHAIN: &str = include_str!("end.chain");
const RSA: &str = include_str!("end.rsa");

let cert = certs(&mut BufReader::new(Cursor::new(CERT)))
.unwrap()
.drain(..)
.map(rustls::Certificate)
.collect();
let mut keys = rsa_private_keys(&mut BufReader::new(Cursor::new(RSA))).unwrap();
let mut keys = keys.drain(..).map(PrivateKey);
let sconfig = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(cert, keys.next().unwrap())
.unwrap();

let mut client_root_cert_store = RootCertStore::empty();
let mut chain = BufReader::new(Cursor::new(CHAIN));
let certs = certs(&mut chain).unwrap();
let trust_anchors = certs
.iter()
.map(|cert| {
let ta = webpki::TrustAnchor::try_from_cert_der(&cert[..]).unwrap();
OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
})
.collect::<Vec<_>>();
client_root_cert_store.add_server_trust_anchors(trust_anchors.into_iter());
let cconfig = ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(client_root_cert_store)
.with_no_client_auth();

(Arc::new(sconfig), Arc::new(cconfig))
}
}

0 comments on commit 3350601

Please sign in to comment.