diff --git a/tokio-boring/src/async_callbacks.rs b/tokio-boring/src/async_callbacks.rs index b12ad2b2..8fa9494b 100644 --- a/tokio-boring/src/async_callbacks.rs +++ b/tokio-boring/src/async_callbacks.rs @@ -1,6 +1,7 @@ use boring::ex_data::Index; use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder}; use once_cell::sync::Lazy; +use std::convert::identity; use std::future::Future; use std::pin::Pin; use std::task::{ready, Context, Poll, Waker}; @@ -19,6 +20,12 @@ pub type BoxPrivateKeyMethodFuture = pub type BoxPrivateKeyMethodFinish = Box Result>; +/// The type of futures to pass to [`SslContextBuilderExt::set_async_get_session_callback`]. +pub type BoxGetSessionFuture = ExDataFuture>; + +/// The type of callbacks returned by [`BoxSelectCertFuture`] methods. +pub type BoxGetSessionFinish = Box Option>; + /// Convenience alias for futures stored in [`Ssl`] ex data by [`SslContextBuilderExt`] methods. /// /// Public for documentation purposes. @@ -31,6 +38,8 @@ pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy>, > = Lazy::new(|| Ssl::new_ex_index().unwrap()); +pub(crate) static SELECT_GET_SESSION_FUTURE_INDEX: Lazy>> = + Lazy::new(|| Ssl::new_ex_index().unwrap()); /// Extensions to [`SslContextBuilder`]. /// @@ -57,6 +66,23 @@ pub trait SslContextBuilderExt: private::Sealed { /// /// See [`AsyncPrivateKeyMethod`] for more details. fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod); + + /// Sets a callback that is called when a client proposed to resume a session + /// but it was not found in the internal cache. + /// + /// The callback is passed a reference to the session ID provided by the client. + /// It should return the session corresponding to that ID if available. This is + /// only used for servers, not clients. + /// + /// See [`SslContextBuilder::set_get_session_callback`] for the sync setter + /// of this callback. + /// + /// # Safety + /// + /// The returned [`SslSession`] must not be associated with a different [`SslContext`]. + unsafe fn set_async_get_session_callback(&mut self, callback: F) + where + F: Fn(&mut ssl::SslRef, &[u8]) -> Option + Send + Sync + 'static; } impl SslContextBuilderExt for SslContextBuilder { @@ -73,6 +99,7 @@ impl SslContextBuilderExt for SslContextBuilder { *SELECT_CERT_FUTURE_INDEX, ClientHello::ssl_mut, &callback, + identity, ); let fut_result = match fut_poll_result { @@ -89,6 +116,29 @@ impl SslContextBuilderExt for SslContextBuilder { fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) { self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method))); } + + unsafe fn set_async_get_session_callback(&mut self, callback: F) + where + F: Fn(&mut ssl::SslRef, &[u8]) -> Option + Send + Sync + 'static, + { + let async_callback = move |ssl: &mut ssl::SslRef, id: &[u8]| { + let fut_poll_result = with_ex_data_future( + &mut *ssl, + *SELECT_GET_SESSION_FUTURE_INDEX, + |ssl| ssl, + |ssl| callback(ssl, id).ok_or(()), + |option| option.ok_or(()), + ); + + match fut_poll_result { + Poll::Ready(Err(())) => Ok(None), + Poll::Ready(Ok(finish)) => Ok(finish(ssl, id)), + Poll::Pending => Err(ssl::GetSessionPendingError), + } + }; + + self.set_get_session_callback(async_callback) + } } /// A fatal error to be returned from async select certificate callbacks. @@ -201,6 +251,7 @@ fn with_private_key_method( *SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX, |ssl| ssl, |ssl| create_fut(ssl, output), + identity, ); let fut_result = match fut_poll_result { @@ -217,11 +268,12 @@ fn with_private_key_method( /// /// This function won't even bother storing the future in `index` if the future /// created by `create_fut` returns `Poll::Ready(_)` on the first poll call. -fn with_ex_data_future( +fn with_ex_data_future( ssl_handle: &mut H, - index: Index>>>, + index: Index>>, get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef, - create_fut: impl FnOnce(&mut H) -> Result>, E>, + create_fut: impl FnOnce(&mut H) -> Result, E>, + into_result: impl Fn(R) -> Result, ) -> Poll> { let ssl = get_ssl_mut(ssl_handle); let waker = ssl @@ -233,7 +285,7 @@ fn with_ex_data_future( let mut ctx = Context::from_waker(&waker); if let Some(data @ Some(_)) = ssl.ex_data_mut(index) { - let fut_result = ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx)); + let fut_result = into_result(ready!(data.as_mut().unwrap().as_mut().poll(&mut ctx))); *data = None; @@ -242,7 +294,7 @@ fn with_ex_data_future( let mut fut = create_fut(ssl_handle)?; match fut.as_mut().poll(&mut ctx) { - Poll::Ready(fut_result) => Poll::Ready(fut_result), + Poll::Ready(fut_result) => Poll::Ready(into_result(fut_result)), Poll::Pending => { get_ssl_mut(ssl_handle).set_ex_data(index, Some(fut)); diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index a8ab50ad..a3c4eae2 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -32,9 +32,9 @@ mod bridge; use self::async_callbacks::TASK_WAKER_INDEX; pub use self::async_callbacks::{ - AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, - BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, BoxSelectCertFuture, - ExDataFuture, SslContextBuilderExt, + AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish, + BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, + BoxSelectCertFuture, ExDataFuture, SslContextBuilderExt, }; use self::bridge::AsyncStreamBridge; diff --git a/tokio-boring/tests/async_get_session.rs b/tokio-boring/tests/async_get_session.rs new file mode 100644 index 00000000..4035acf0 --- /dev/null +++ b/tokio-boring/tests/async_get_session.rs @@ -0,0 +1,111 @@ +use boring::ssl::{SslOptions, SslRef, SslSession, SslSessionCacheMode, SslVersion}; +use futures::future; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::OnceLock; +use tokio::net::TcpStream; +use tokio::task::yield_now; +use tokio_boring::{BoxGetSessionFinish, SslContextBuilderExt}; + +mod common; + +use self::common::{create_acceptor, create_connector, create_listener}; + +#[tokio::test] +async fn test() { + static FOUND_SESSION: AtomicBool = AtomicBool::new(false); + static SERVER_SESSION_DER: OnceLock> = OnceLock::new(); + static CLIENT_SESSION_DER: OnceLock> = OnceLock::new(); + + let (listener, addr) = create_listener(); + + let acceptor = create_acceptor(move |builder| { + builder + .set_max_proto_version(Some(SslVersion::TLS1_2)) + .unwrap(); + builder.set_options(SslOptions::NO_TICKET); + builder + .set_session_cache_mode(SslSessionCacheMode::SERVER | SslSessionCacheMode::NO_INTERNAL); + builder.set_new_session_callback(|_, session| { + SERVER_SESSION_DER.set(session.to_der().unwrap()).unwrap() + }); + + unsafe { + builder.set_async_get_session_callback(|_, _| { + let Some(der) = SERVER_SESSION_DER.get() else { + return None; + }; + + Some(Box::pin(async move { + yield_now().await; + + FOUND_SESSION.store(true, Ordering::SeqCst); + + Some(Box::new(|_: &mut SslRef, _: &[u8]| { + Some(SslSession::from_der(der).unwrap()) + }) as BoxGetSessionFinish) + })) + }); + } + }); + + let connector = create_connector(|builder| { + builder.set_session_cache_mode(SslSessionCacheMode::CLIENT); + builder.set_new_session_callback(|_, session| { + assert_eq!( + SslSession::from_der(SERVER_SESSION_DER.get().unwrap()) + .unwrap() + .id(), + session.id() + ); + + CLIENT_SESSION_DER.set(session.to_der().unwrap()).unwrap() + }); + + builder.set_ca_file("tests/cert.pem") + }); + + let server = async move { + tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0) + .await + .unwrap(); + + assert!(SERVER_SESSION_DER.get().is_some()); + assert!(!FOUND_SESSION.load(Ordering::SeqCst)); + + tokio_boring::accept(&acceptor, listener.accept().await.unwrap().0) + .await + .unwrap(); + + assert!(FOUND_SESSION.load(Ordering::SeqCst)); + }; + + let client = async move { + tokio_boring::connect( + connector.configure().unwrap(), + "localhost", + TcpStream::connect(&addr).await.unwrap(), + ) + .await + .unwrap(); + + let der = CLIENT_SESSION_DER.get().unwrap(); + + let mut config = connector.configure().unwrap(); + + unsafe { + config + .set_session(&SslSession::from_der(der).unwrap()) + .unwrap(); + } + + tokio_boring::connect( + config, + "localhost", + TcpStream::connect(&addr).await.unwrap(), + ) + .await + .unwrap(); + }; + + future::join(server, client).await; +} diff --git a/tokio-boring/tests/common/mod.rs b/tokio-boring/tests/common/mod.rs index 6ed394ef..b28917b4 100644 --- a/tokio-boring/tests/common/mod.rs +++ b/tokio-boring/tests/common/mod.rs @@ -17,6 +17,20 @@ pub(crate) fn create_server( impl Future, HandshakeError>>, SocketAddr, ) { + let (listener, addr) = create_listener(); + + let server = async move { + let acceptor = create_acceptor(setup); + + let stream = listener.accept().await.unwrap().0; + + tokio_boring::accept(&acceptor, stream).await + }; + + (server, addr) +} + +pub(crate) fn create_listener() -> (TcpListener, SocketAddr) { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); listener.set_nonblocking(true).unwrap(); @@ -24,44 +38,46 @@ pub(crate) fn create_server( let listener = TcpListener::from_std(listener).unwrap(); let addr = listener.local_addr().unwrap(); - let server = async move { - let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - - acceptor - .set_private_key_file("tests/key.pem", SslFiletype::PEM) - .unwrap(); - - acceptor - .set_certificate_chain_file("tests/cert.pem") - .unwrap(); + (listener, addr) +} - setup(&mut acceptor); +pub(crate) fn create_acceptor(setup: impl FnOnce(&mut SslAcceptorBuilder)) -> SslAcceptor { + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); - let acceptor = acceptor.build(); + acceptor + .set_private_key_file("tests/key.pem", SslFiletype::PEM) + .unwrap(); - let stream = listener.accept().await.unwrap().0; + acceptor + .set_certificate_chain_file("tests/cert.pem") + .unwrap(); - tokio_boring::accept(&acceptor, stream).await - }; + setup(&mut acceptor); - (server, addr) + acceptor.build() } pub(crate) async fn connect( addr: SocketAddr, setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>, ) -> Result, HandshakeError> { - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); - - setup(&mut connector).unwrap(); - - let config = connector.build().configure().unwrap(); + let config = create_connector(setup).configure().unwrap(); let stream = TcpStream::connect(&addr).await.unwrap(); tokio_boring::connect(config, "localhost", stream).await } +pub(crate) fn create_connector( + setup: impl FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>, +) -> SslConnector { + let mut connector = SslConnector::builder(SslMethod::tls()).unwrap(); + + setup(&mut connector).unwrap(); + + connector.build() +} + pub(crate) async fn with_trivial_client_server_exchange( server_setup: impl FnOnce(&mut SslAcceptorBuilder), ) {