diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 04a48a973..b8163db8f 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -37,7 +37,7 @@ pub trait IoCallback { fn data_ptr(&mut self) -> *mut c_void; } -impl IoCallback for IO { +impl IoCallback for IO { unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { let len = if len > (c_int::max_value() as size_t) { c_int::max_value() as size_t @@ -140,7 +140,7 @@ impl Context { } } -impl Context { +impl Context { pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> { unsafe { let mut io = Box::new(io); @@ -412,13 +412,67 @@ impl HandshakeContext { #[cfg(test)] mod tests { - use crate::ssl::context::HandshakeContext; + use crate::ssl::context::{HandshakeContext, Context}; use crate::tests::TestTrait; - + use std::io::{Read,Write, Result as IoResult}; + use core::ptr::NonNull; + #[test] fn handshakecontext_sync() { assert!(!TestTrait::::new().impls_trait(), "HandshakeContext must be !Sync"); } + + struct NonSendStream { + _buffer: NonNull, + } + + impl Read for NonSendStream { + fn read(&mut self, _: &mut [u8]) -> IoResult { + unimplemented!() + } + } + + impl Write for NonSendStream { + fn write(&mut self, _: &[u8]) -> IoResult { + unimplemented!() + } + + fn flush(&mut self) -> IoResult<()> { + unimplemented!() + } + } + + + struct SendStream { + _buffer: Vec, + } + + impl Read for SendStream { + fn read(&mut self, _: &mut [u8]) -> IoResult { + unimplemented!() + } + } + + impl Write for SendStream { + fn write(&mut self, _: &[u8]) -> IoResult { + unimplemented!() + } + + fn flush(&mut self) -> IoResult<()> { + unimplemented!() + } + } + + + #[test] + fn context_send() { + assert!(!TestTrait::::new().impls_trait(), "NonSendStream can't be send"); + assert!(!TestTrait::>::new().impls_trait(), "Context can't be send"); + + assert!(TestTrait::::new().impls_trait(), "SendStream is send"); + assert!(TestTrait::>::new().impls_trait(), "Context is send"); + } + } // ssl_get_alpn_protocol diff --git a/mbedtls/tests/hyper.rs b/mbedtls/tests/hyper.rs index 475599066..58fb9b065 100644 --- a/mbedtls/tests/hyper.rs +++ b/mbedtls/tests/hyper.rs @@ -25,18 +25,14 @@ impl TlsStream { } } -unsafe impl Send for TlsStream {} -unsafe impl Sync for TlsStream {} - - -impl io::Read for TlsStream +impl io::Read for TlsStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { self.context.lock().unwrap().read(buf) } } -impl io::Write for TlsStream +impl io::Write for TlsStream { fn write(&mut self, buf: &[u8]) -> io::Result { self.context.lock().unwrap().write(buf)