diff --git a/boring/src/ssl/mod.rs b/boring/src/ssl/mod.rs index 1420ee32..90ad1ed8 100644 --- a/boring/src/ssl/mod.rs +++ b/boring/src/ssl/mod.rs @@ -864,7 +864,9 @@ impl SslContextBuilder { // out. When that happens, we wouldn't be able to look up the callback's state in the // context's ex data. Instead, pass the pointer directly as the servername arg. It's // still stored in ex data to manage the lifetime. - let arg = self.set_ex_data_inner(SslContext::cached_ex_index::(), callback); + let arg = self + .ctx + .set_ex_data(SslContext::cached_ex_index::(), callback); ffi::SSL_CTX_set_tlsext_servername_arg(self.as_ptr(), arg); ffi::SSL_CTX_set_tlsext_servername_callback(self.as_ptr(), Some(raw_sni::)); @@ -1653,19 +1655,30 @@ impl SslContextBuilder { /// /// This corresponds to [`SSL_CTX_set_ex_data`]. /// + /// Note that if this method is called multiple times with the same index, any previous + /// value stored in the `SslContextBuilder` will be leaked. + /// /// [`SSL_CTX_set_ex_data`]: https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_ex_data.html pub fn set_ex_data(&mut self, index: Index, data: T) { - self.set_ex_data_inner(index, data); - } - - fn set_ex_data_inner(&mut self, index: Index, data: T) -> *mut c_void { unsafe { - let data = Box::into_raw(Box::new(data)) as *mut c_void; - ffi::SSL_CTX_set_ex_data(self.as_ptr(), index.as_raw(), data); - data + self.ctx.set_ex_data(index, data); } } + /// Sets or overwrites the extra data at the specified index. + /// + /// This can be used to provide data to callbacks registered with the context. Use the + /// `Ssl::new_ex_index` method to create an `Index`. + /// + /// This corresponds to [`SSL_set_ex_data`]. + /// + /// Any previous value will be returned and replaced by the new one. + /// + /// [`SSL_set_ex_data`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_ex_data.html + pub fn replace_ex_data(&mut self, index: Index, data: T) -> Option { + unsafe { self.ctx.replace_ex_data(index, data) } + } + /// Sets the context's session cache size limit, returning the previous limit. /// /// A value of 0 means that the cache size is unbounded. @@ -1916,6 +1929,39 @@ impl SslContextRef { } } + // Unsafe because SSL contexts are not guaranteed to be unique, we call + // this only from SslContextBuilder. + unsafe fn ex_data_mut(&mut self, index: Index) -> Option<&mut T> { + let data = ffi::SSL_CTX_get_ex_data(self.as_ptr(), index.as_raw()); + if data.is_null() { + None + } else { + Some(&mut *(data as *mut T)) + } + } + + // Unsafe because SSL contexts are not guaranteed to be unique, we call + // this only from SslContextBuilder. + unsafe fn set_ex_data(&mut self, index: Index, data: T) -> *mut c_void { + unsafe { + let data = Box::into_raw(Box::new(data)) as *mut c_void; + ffi::SSL_CTX_set_ex_data(self.as_ptr(), index.as_raw(), data); + data + } + } + + // Unsafe because SSL contexts are not guaranteed to be unique, we call + // this only from SslContextBuilder. + unsafe fn replace_ex_data(&mut self, index: Index, data: T) -> Option { + if let Some(old) = self.ex_data_mut(index) { + return Some(mem::replace(old, data)); + } + + self.set_ex_data(index, data); + + None + } + /// Adds a session to the context's cache. /// /// Returns `true` if the session was successfully added to the cache, and `false` if it was already present. @@ -3191,8 +3237,17 @@ impl SslRef { /// /// This corresponds to [`SSL_set_ex_data`]. /// + /// Note that if this method is called multiple times with the same index, any previous + /// value stored in the `SslContextBuilder` will be leaked. + /// /// [`SSL_set_ex_data`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_ex_data.html pub fn set_ex_data(&mut self, index: Index, data: T) { + if let Some(old) = self.ex_data_mut(index) { + *old = data; + + return; + } + unsafe { let data = Box::new(data); ffi::SSL_set_ex_data( @@ -3203,6 +3258,26 @@ impl SslRef { } } + /// Sets or overwrites the extra data at the specified index. + /// + /// This can be used to provide data to callbacks registered with the context. Use the + /// `Ssl::new_ex_index` method to create an `Index`. + /// + /// This corresponds to [`SSL_set_ex_data`]. + /// + /// Any previous value will be dropped and replaced by the new one. + /// + /// [`SSL_set_ex_data`]: https://www.openssl.org/docs/manmaster/man3/SSL_set_ex_data.html + pub fn replace_ex_data(&mut self, index: Index, data: T) -> Option { + if let Some(old) = self.ex_data_mut(index) { + return Some(mem::replace(old, data)); + } + + self.set_ex_data(index, data); + + None + } + /// Returns a reference to the extra data at the specified index. /// /// This corresponds to [`SSL_get_ex_data`]. diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index d3319f5f..c6199352 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -1044,3 +1044,24 @@ fn server_set_default_curves_list() { // Panics if Kyber768 missing in boringSSL. ssl.server_set_default_curves_list(); } + +#[test] +fn drop_ex_data_in_context() { + let index = SslContext::new_ex_index::<&'static str>().unwrap(); + let mut ctx = SslContext::builder(SslMethod::dtls()).unwrap(); + + assert_eq!(ctx.replace_ex_data(index, "comté"), None); + assert_eq!(ctx.replace_ex_data(index, "camembert"), Some("comté")); + assert_eq!(ctx.replace_ex_data(index, "raclette"), Some("camembert")); +} + +#[test] +fn drop_ex_data_in_ssl() { + let index = Ssl::new_ex_index::<&'static str>().unwrap(); + let ctx = SslContext::builder(SslMethod::dtls()).unwrap().build(); + let mut ssl = Ssl::new(&ctx).unwrap(); + + assert_eq!(ssl.replace_ex_data(index, "comté"), None); + assert_eq!(ssl.replace_ex_data(index, "camembert"), Some("comté")); + assert_eq!(ssl.replace_ex_data(index, "raclette"), Some("camembert")); +} diff --git a/tokio-boring/src/lib.rs b/tokio-boring/src/lib.rs index adf4a63a..ec837dc9 100644 --- a/tokio-boring/src/lib.rs +++ b/tokio-boring/src/lib.rs @@ -300,18 +300,20 @@ where mid_handshake.get_mut().set_waker(Some(ctx)); mid_handshake .ssl_mut() - .set_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone())); + .replace_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone())); match mid_handshake.handshake() { Ok(mut stream) => { stream.get_mut().set_waker(None); - stream.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None); + stream.ssl_mut().replace_ex_data(*TASK_WAKER_INDEX, None); Poll::Ready(Ok(SslStream(stream))) } Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => { mid_handshake.get_mut().set_waker(None); - mid_handshake.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None); + mid_handshake + .ssl_mut() + .replace_ex_data(*TASK_WAKER_INDEX, None); self.0 = Some(mid_handshake);