Skip to content

Commit

Permalink
Introduce set_async_get_session_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
nox committed Oct 24, 2023
1 parent df60d1e commit 2ffe1fe
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 29 deletions.
62 changes: 57 additions & 5 deletions tokio-boring/src/async_callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use boring::ex_data::Index;
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
use once_cell::sync::Lazy;
use std::convert::identity;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll, Waker};
Expand All @@ -19,6 +20,12 @@ pub type BoxPrivateKeyMethodFuture =
pub type BoxPrivateKeyMethodFinish =
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;

/// The type of futures to pass to [`SslContextBuilderExt::set_async_get_session_callback`].
pub type BoxGetSessionFuture = ExDataFuture<Option<BoxGetSessionFinish>>;

/// The type of callbacks returned by [`BoxSelectCertFuture`] methods.
pub type BoxGetSessionFinish = Box<dyn FnOnce(&mut ssl::SslRef, &[u8]) -> Option<ssl::SslSession>>;

/// Convenience alias for futures stored in [`Ssl`] ex data by [`SslContextBuilderExt`] methods.
///
/// Public for documentation purposes.
Expand All @@ -31,6 +38,8 @@ pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, Option<BoxSelectCert
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
Index<Ssl, Option<BoxPrivateKeyMethodFuture>>,
> = Lazy::new(|| Ssl::new_ex_index().unwrap());
pub(crate) static SELECT_GET_SESSION_FUTURE_INDEX: Lazy<Index<Ssl, Option<BoxGetSessionFuture>>> =
Lazy::new(|| Ssl::new_ex_index().unwrap());

/// Extensions to [`SslContextBuilder`].
///
Expand All @@ -57,6 +66,23 @@ pub trait SslContextBuilderExt: private::Sealed {
///
/// See [`AsyncPrivateKeyMethod`] for more details.
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);

/// Sets a callback that is called when a client proposed to resume a session
/// but it was not found in the internal cache.
///
/// The callback is passed a reference to the session ID provided by the client.
/// It should return the session corresponding to that ID if available. This is
/// only used for servers, not clients.
///
/// See [`SslContextBuilder::set_get_session_callback`] for the sync setter
/// of this callback.
///
/// # Safety
///
/// The returned [`SslSession`] must not be associated with a different [`SslContext`].
unsafe fn set_async_get_session_callback<F>(&mut self, callback: F)
where
F: Fn(&mut ssl::SslRef, &[u8]) -> Option<BoxGetSessionFuture> + Send + Sync + 'static;
}

impl SslContextBuilderExt for SslContextBuilder {
Expand All @@ -73,6 +99,7 @@ impl SslContextBuilderExt for SslContextBuilder {
*SELECT_CERT_FUTURE_INDEX,
ClientHello::ssl_mut,
&callback,
identity,
);

let fut_result = match fut_poll_result {
Expand All @@ -89,6 +116,29 @@ impl SslContextBuilderExt for SslContextBuilder {
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
}

unsafe fn set_async_get_session_callback<F>(&mut self, callback: F)
where
F: Fn(&mut ssl::SslRef, &[u8]) -> Option<BoxGetSessionFuture> + Send + Sync + 'static,
{
let async_callback = move |ssl: &mut ssl::SslRef, id: &[u8]| {
let fut_poll_result = with_ex_data_future(
&mut *ssl,
*SELECT_GET_SESSION_FUTURE_INDEX,
|ssl| ssl,
|ssl| callback(ssl, id).ok_or(()),
|option| option.ok_or(()),
);

match fut_poll_result {
Poll::Ready(Err(())) => Ok(None),
Poll::Ready(Ok(finish)) => Ok(finish(ssl, id)),
Poll::Pending => Err(ssl::GetSessionPendingError),
}
};

self.set_get_session_callback(async_callback)
}
}

/// A fatal error to be returned from async select certificate callbacks.
Expand Down Expand Up @@ -201,6 +251,7 @@ fn with_private_key_method(
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
|ssl| ssl,
|ssl| create_fut(ssl, output),
identity,
);

let fut_result = match fut_poll_result {
Expand All @@ -217,11 +268,12 @@ fn with_private_key_method(
///
/// This function won't even bother storing the future in `index` if the future
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
fn with_ex_data_future<H, T, E>(
fn with_ex_data_future<H, R, T, E>(
ssl_handle: &mut H,
index: Index<ssl::Ssl, Option<ExDataFuture<Result<T, E>>>>,
index: Index<ssl::Ssl, Option<ExDataFuture<R>>>,
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<R>, E>,
into_result: impl Fn(R) -> Result<T, E>,
) -> Poll<Result<T, E>> {
let ssl = get_ssl_mut(ssl_handle);
let waker = ssl
Expand All @@ -233,7 +285,7 @@ fn with_ex_data_future<H, T, E>(
let mut ctx = Context::from_waker(&waker);

if let Some(data @ Some(_)) = ssl.ex_data_mut(index) {
let fut_result = ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx));
let fut_result = into_result(ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx)));

*data = None;

Expand All @@ -242,7 +294,7 @@ fn with_ex_data_future<H, T, E>(
let mut fut = create_fut(ssl_handle)?;

match fut.as_mut().poll(&mut ctx) {
Poll::Ready(fut_result) => Poll::Ready(fut_result),
Poll::Ready(fut_result) => Poll::Ready(into_result(fut_result)),
Poll::Pending => {
get_ssl_mut(ssl_handle).set_ex_data(index, Some(fut));

Expand Down
6 changes: 3 additions & 3 deletions tokio-boring/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ mod bridge;

use self::async_callbacks::TASK_WAKER_INDEX;
pub use self::async_callbacks::{
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError,
BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, BoxSelectCertFuture,
ExDataFuture, SslContextBuilderExt,
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
BoxSelectCertFuture, ExDataFuture, SslContextBuilderExt,
};
use self::bridge::AsyncStreamBridge;

Expand Down
111 changes: 111 additions & 0 deletions tokio-boring/tests/async_get_session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use boring::ssl::{SslOptions, SslRef, SslSession, SslSessionCacheMode, SslVersion};
use futures::future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::OnceLock;
use tokio::net::TcpStream;
use tokio::task::yield_now;
use tokio_boring::{BoxGetSessionFinish, SslContextBuilderExt};

mod common;

use self::common::{create_acceptor, create_connector, create_listener};

#[tokio::test]
async fn test() {
static FOUND_SESSION: AtomicBool = AtomicBool::new(false);
static SERVER_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();
static CLIENT_SESSION_DER: OnceLock<Vec<u8>> = OnceLock::new();

let (listener, addr) = create_listener();

let acceptor = create_acceptor(move |builder| {
builder
.set_max_proto_version(Some(SslVersion::TLS1_2))
.unwrap();
builder.set_options(SslOptions::NO_TICKET);
builder
.set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL);
builder.set_new_session_callback(|_, session| {
SERVER_SESSION_DER.set(session.to_der().unwrap()).unwrap()
});

unsafe {
builder.set_async_get_session_callback(|_, _| {
let Some(der) = SERVER_SESSION_DER.get() else {
return None;
};

Some(Box::pin(async move {
yield_now().await;

FOUND_SESSION.store(true, Ordering::SeqCst);

Some(Box::new(|_: &mut SslRef, _: &[u8]| {
Some(SslSession::from_der(der).unwrap())
}) as BoxGetSessionFinish)
}))
});
}
});

let connector = create_connector(|builder| {
builder.set_session_cache_mode(SslSessionCacheMode::CLIENT);
builder.set_new_session_callback(|_, session| {
assert_eq!(
SslSession::from_der(SERVER_SESSION_DER.get().unwrap())
.unwrap()
.id(),
session.id()
);

CLIENT_SESSION_DER.set(session.to_der().unwrap()).unwrap()
});

builder.set_ca_file("tests/cert.pem")
});

let server = async move {
tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0)
.await
.unwrap();

assert!(SERVER_SESSION_DER.get().is_some());
assert!(!FOUND_SESSION.load(Ordering::SeqCst));

tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0)
.await
.unwrap();

assert!(FOUND_SESSION.load(Ordering::SeqCst));
};

let client = async move {
tokio_boring::connect(
connector.configure().unwrap(),
"localhost",
TcpStream::connect(&addr).await.unwrap(),
)
.await
.unwrap();

let der = CLIENT_SESSION_DER.get().unwrap();

let mut config = connector.configure().unwrap();

unsafe {
config
.set_session(&SslSession::from_der(der).unwrap())
.unwrap();
}

tokio_boring::connect(
config,
"localhost",
TcpStream::connect(&addr).await.unwrap(),
)
.await
.unwrap();
};

future::join(server, client).await;
}
58 changes: 37 additions & 21 deletions tokio-boring/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,51 +17,67 @@ pub(crate) fn create_server(
impl Future<Output = Result<SslStream<TcpStream>, HandshakeError<TcpStream>>>,
SocketAddr,
) {
let (listener, addr) = create_listener();

let server = async move {
let acceptor = create_acceptor(setup);

let stream = listener.accept().await.unwrap().0;

tokio_boring::accept(&acceptor, stream).await
};

(server, addr)
}

pub(crate) fn create_listener() -> (TcpListener, SocketAddr) {
let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();

listener.set_nonblocking(true).unwrap();

let listener = TcpListener::from_std(listener).unwrap();
let addr = listener.local_addr().unwrap();

let server = async move {
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();

acceptor
.set_private_key_file("tests/key.pem", SslFiletype::PEM)
.unwrap();

acceptor
.set_certificate_chain_file("tests/cert.pem")
.unwrap();
(listener, addr)
}

setup(&mut acceptor);
pub(crate) fn create_acceptor(setup: impl FnOnce(&mut SslAcceptorBuilder)) -> SslAcceptor {
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap();

let acceptor = acceptor.build();
acceptor
.set_private_key_file("tests/key.pem", SslFiletype::PEM)
.unwrap();

let stream = listener.accept().await.unwrap().0;
acceptor
.set_certificate_chain_file("tests/cert.pem")
.unwrap();

tokio_boring::accept(&acceptor, stream).await
};
setup(&mut acceptor);

(server, addr)
acceptor.build()
}

pub(crate) async fn connect(
addr: SocketAddr,
setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
) -> Result<SslStream<TcpStream>, HandshakeError<TcpStream>> {
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();

setup(&mut connector).unwrap();

let config = connector.build().configure().unwrap();
let config = create_connector(setup).configure().unwrap();

let stream = TcpStream::connect(&addr).await.unwrap();

tokio_boring::connect(config, "localhost", stream).await
}

pub(crate) fn create_connector(
setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
) -> SslConnector {
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap();

setup(&mut connector).unwrap();

connector.build()
}

pub(crate) async fn with_trivial_client_server_exchange(
server_setup: impl FnOnce(&mut SslAcceptorBuilder),
) {
Expand Down

0 comments on commit 2ffe1fe

Please sign in to comment.