diff --git a/mbedtls/examples/client_dtls.rs b/mbedtls/examples/client_dtls.rs index d4776e30d..e079b2aa9 100644 --- a/mbedtls/examples/client_dtls.rs +++ b/mbedtls/examples/client_dtls.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use mbedtls::rng::CtrDrbg; use mbedtls::ssl::config::{Endpoint, Preset, Transport}; -use mbedtls::ssl::{Config, Context}; +use mbedtls::ssl::{Config, Context, Io}; use mbedtls::x509::Certificate; use mbedtls::Result as TlsResult; @@ -35,7 +35,7 @@ fn result_main(addr: &str) -> TlsResult<()> { ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new())); let sock = UdpSocket::bind("localhost:12345").unwrap(); - let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap(); + let sock = mbedtls::ssl::io::ConnectedUdpSocket::connect(sock, addr).unwrap(); ctx.establish(sock, None).unwrap(); let mut line = String::new(); diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 8880cb785..966affd95 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -52,6 +52,7 @@ define!( define!( #[c_ty(c_int)] + #[derive(PartialEq, Eq)] enum Transport { /// TLS Stream = SSL_TRANSPORT_STREAM, diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index 8ce2d6574..20e0b4445 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -9,13 +9,9 @@ use core::result::Result as StdResult; #[cfg(feature = "std")] -use { - std::io::{Read, Write, Result as IoResult, Error as IoError}, - std::sync::Arc, -}; +use std::sync::Arc; -use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; -use mbedtls_sys::types::size_t; +use mbedtls_sys::types::raw_types::{c_int, c_void}; use mbedtls_sys::*; #[cfg(not(feature = "std"))] @@ -25,94 +21,9 @@ use crate::error::{Error, Result, IntoResult}; use crate::pk::Pk; use crate::private::UnsafeFrom; use crate::ssl::config::{Config, Version, AuthMode}; +use crate::ssl::io::IoCallbackUnsafe; use crate::x509::{Certificate, Crl, VerifyError}; -pub trait IoCallback { - unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized; - unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized; - fn data_ptr(&mut self) -> *mut c_void; -} - -#[cfg(feature = "std")] -impl IoCallback for IO { - unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { - let len = if len > (c_int::max_value() as size_t) { - c_int::max_value() as size_t - } else { - len - }; - match (&mut *(user_data as *mut IO)).read(::core::slice::from_raw_parts_mut(data, len)) { - Ok(i) => i as c_int, - Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED, - } - } - - unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { - let len = if len > (c_int::max_value() as size_t) { - c_int::max_value() as size_t - } else { - len - }; - match (&mut *(user_data as *mut IO)).write(::core::slice::from_raw_parts(data, len)) { - Ok(i) => i as c_int, - Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED, - } - } - - fn data_ptr(&mut self) -> *mut c_void { - self as *mut IO as *mut _ - } -} - -#[cfg(feature = "std")] -pub struct ConnectedUdpSocket { - socket: std::net::UdpSocket, -} - -#[cfg(feature = "std")] -impl ConnectedUdpSocket { - pub fn connect(socket: std::net::UdpSocket, addr: A) -> StdResult { - match socket.connect(addr) { - Ok(_) => Ok(ConnectedUdpSocket { - socket, - }), - Err(e) => Err((e, socket)), - } - } -} - -#[cfg(feature = "std")] -impl IoCallback for ConnectedUdpSocket { - unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { - let len = if len > (c_int::max_value() as size_t) { - c_int::max_value() as size_t - } else { - len - }; - match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) { - Ok(i) => i as c_int, - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0, - Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED, - } - } - - unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { - let len = if len > (c_int::max_value() as size_t) { - c_int::max_value() as size_t - } else { - len - }; - match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) { - Ok(i) => i as c_int, - Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED, - } - } - - fn data_ptr(&mut self) -> *mut c_void { - self as *mut ConnectedUdpSocket as *mut c_void - } -} - pub trait TimerCallback: Send + Sync { unsafe extern "C" fn set_timer( p_timer: *mut c_void, @@ -261,8 +172,13 @@ impl Context { } } -impl Context { - pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> { +impl Context { + /// Establish a TLS session on the given `io`. + /// + /// Upon succesful return, the context can be communicated with using the + /// `std::io::Read` and `std::io::Write` traits if `io` implements those as + /// well, and using the `mbedtls::ssl::io::Io` trait otherwise. + pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> where T: IoCallbackUnsafe { unsafe { let mut io = Box::new(io); ssl_session_reset(self.into()).into_result()?; @@ -292,7 +208,7 @@ impl Context { /// Try to complete the handshake procedure to set up a (D)TLS connection /// /// In general, this should not be called directly. Instead, [`establish`](Context::establish) - /// should be used which properly sets up the [`IoCallback`] and resets any previous sessions. + /// should be used which properly sets up the [`IoCallbackUnsafe`] and resets any previous sessions. /// /// This should only be used directly if the handshake could not be completed successfully in /// `establish`, i.e.: @@ -483,16 +399,14 @@ impl Context { pub fn set_client_transport_id_once(&mut self, info: &[u8]) { self.client_transport_id = Some(info.into()); } -} -impl Context { - pub fn recv(&mut self, buf: &mut [u8]) -> Result { + pub(super) fn recv(&mut self, buf: &mut [u8]) -> Result { unsafe { ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result().map(|r| r as usize) } } - pub fn send(&mut self, buf: &[u8]) -> Result { + pub(super) fn send(&mut self, buf: &[u8]) -> Result { unsafe { ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result().map(|w| w as usize) } @@ -508,40 +422,6 @@ impl Drop for Context { } } -#[cfg(feature = "std")] -/// Implements [`std::io::Read`] whenever T implements `Read`, too. This ensures that -/// `Read`, which is designated for byte-oriented sources, is only implemented when the -/// underlying [`IoCallback`] is byte-oriented, too. Specifically, this means that it is implemented -/// for `Context`, i.e. TLS connections but not for DTLS connections. -impl Read for Context { - fn read(&mut self, buf: &mut [u8]) -> IoResult { - match self.recv(buf) { - Err(Error::SslPeerCloseNotify) => Ok(0), - Err(e) => Err(crate::private::error_to_io_error(e)), - Ok(i) => Ok(i), - } - } -} - -#[cfg(feature = "std")] -/// Implements [`std::io::Write`] whenever T implements `Write`, too. This ensures that -/// `Write`, which is designated for byte-oriented sinks, is only implemented when the -/// underlying [`IoCallback`] is byte-oriented, too. Specifically, this means that it is implemented -/// for `Context`, i.e. TLS connections but not for DTLS connections. -impl Write for Context { - fn write(&mut self, buf: &[u8]) -> IoResult { - match self.send(buf) { - Err(Error::SslPeerCloseNotify) => Ok(0), - Err(e) => Err(crate::private::error_to_io_error(e)), - Ok(i) => Ok(i), - } - } - - fn flush(&mut self) -> IoResult<()> { - Ok(()) - } -} - // // Class exists only during SNI callback that is configured from Config. // SNI Callback must provide input whose lifetime exceeds the SNI closure to avoid memory corruptions. diff --git a/mbedtls/src/ssl/io.rs b/mbedtls/src/ssl/io.rs new file mode 100644 index 000000000..4ca1e0f0a --- /dev/null +++ b/mbedtls/src/ssl/io.rs @@ -0,0 +1,217 @@ +/* Copyright (c) Fortanix, Inc. + * + * Licensed under the GNU General Public License, version 2 or the Apache License, Version + * 2.0 , at your + * option. This file may not be copied, modified, or distributed except + * according to those terms. */ +//! Various I/O abstractions for use with MbedTLS's TLS sessions. +//! +//! If you are using `std::net::TcpStream` or any `std::io::Read` and +//! `std::io::Write` streams, you probably don't need to look at any of this. +//! Just pass your stream directly to `Context::establish`. If you want to use +//! a `std::net::UdpSocket` with DTLS, take a look at `ConnectedUdpSocket`. If +//! you are implementing your own communication types or traits, consider +//! implementing `Io` for them. If all else fails, implement `IoCallback`. + +#[cfg(feature = "std")] +use std::{ + io::{Read, Write, Result as IoResult, Error as IoError, ErrorKind as IoErrorKind}, + net::UdpSocket, + result::Result as StdResult, +}; + +use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void}; +use mbedtls_sys::types::size_t; + +#[cfg(feature = "std")] +use crate::error::Error; +use crate::error::Result; +use super::context::Context; + +/// A direct representation of the `mbedtls_ssl_send_t` and `mbedtls_ssl_recv_t` +/// callback function pointers. +/// +/// You probably want to implement `IoCallback` instead. +pub trait IoCallbackUnsafe { + unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized; + unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized; + fn data_ptr(&mut self) -> *mut c_void; +} + +/// A safe representation of the `mbedtls_ssl_send_t` and `mbedtls_ssl_recv_t` +/// callback function pointers. +/// +/// `T` specifies whether this abstracts an implementation of `std::io::Read` +/// and `std::io::Write` or the more generic `Io` type. See the `Stream` and +/// `AnyIo` types in this module. +pub trait IoCallback { + fn recv(&mut self, buf: &mut [u8]) -> Result; + fn send(&mut self, buf: &[u8]) -> Result; +} + +impl, T> IoCallbackUnsafe for IO { + unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + match (&mut *(user_data as *mut IO)).recv(::core::slice::from_raw_parts_mut(data, len)) { + Ok(i) => i as c_int, + Err(e) => e.to_int(), + } + } + + unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { + let len = if len > (c_int::max_value() as size_t) { + c_int::max_value() as size_t + } else { + len + }; + match (&mut *(user_data as *mut IO)).send(::core::slice::from_raw_parts(data, len)) { + Ok(i) => i as c_int, + Err(e) => e.to_int(), + } + } + + fn data_ptr(&mut self) -> *mut c_void { + self as *mut IO as *mut _ + } +} + +/// Marker type for an IO implementation that doesn't implement `std::io::Read` +/// and `std::io::Write`. +pub enum AnyIo {} +#[cfg(feature = "std")] +/// Marker type for an IO implementation that implements both `std::io::Read` +/// and `std::io::Write`. +pub enum Stream {} + +/// Read and write bytes or packets. +/// +/// Implementors represent a duplex socket or file descriptor that can be read +/// from or written to. +/// +/// You can wrap any type of `Io` with `Context::establish` to protect that +/// communication channel with (D)TLS. That `Context` then also implements `Io` +/// so you can use it interchangeably. +/// +/// If you are using byte streams and are using `std`, you don't need this trait +/// and can rely on `std::io::Read` and `std::io::Write` instead. +pub trait Io { + fn recv(&mut self, buf: &mut [u8]) -> Result; + fn send(&mut self, buf: &[u8]) -> Result; +} + +impl IoCallback for IO { + fn recv(&mut self, buf: &mut [u8]) -> Result { + Io::recv(self, buf) + } + + fn send(&mut self, buf: &[u8]) -> Result { + Io::send(self, buf) + } +} + +#[cfg(feature = "std")] +impl IoCallback for IO { + fn recv(&mut self, buf: &mut [u8]) -> Result { + self.read(buf).map_err(|e| match e { + ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::SslWantRead, + _ => Error::NetRecvFailed + }) + } + + fn send(&mut self, buf: &[u8]) -> Result { + self.write(buf).map_err(|e| match e { + ref e if e.kind() == std::io::ErrorKind::WouldBlock => Error::SslWantWrite, + _ => Error::NetSendFailed + }) + } +} + +#[cfg(feature = "std")] +/// A `UdpSocket` on which `connect` was succesfully called. +/// +/// Construct this type using `ConnectedUdpSocket::connect`. +pub struct ConnectedUdpSocket { + socket: UdpSocket, +} + +#[cfg(feature = "std")] +impl ConnectedUdpSocket { + pub fn connect(socket: UdpSocket, addr: A) -> StdResult { + match socket.connect(addr) { + Ok(_) => Ok(ConnectedUdpSocket { + socket, + }), + Err(e) => Err((e, socket)), + } + } + + pub fn into_socket(self) -> UdpSocket { + self.socket + } +} + +#[cfg(feature = "std")] +impl Io for ConnectedUdpSocket { + fn recv(&mut self, buf: &mut [u8]) -> Result { + match self.socket.recv(buf) { + Ok(i) => Ok(i), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => Err(Error::SslWantRead), + Err(_) => Err(Error::NetRecvFailed) + } + } + + fn send(&mut self, buf: &[u8]) -> Result { + self.socket.send(buf).map_err(|_| Error::NetSendFailed) + } +} + +impl> Io for Context { + fn recv(&mut self, buf: &mut [u8]) -> Result { + Context::recv(self, buf) + } + + fn send(&mut self, buf: &[u8]) -> Result { + Context::send(self, buf) + } +} + +#[cfg(feature = "std")] +/// Implements [`std::io::Read`] whenever T implements `Read`, too. This ensures that +/// `Read`, which is designated for byte-oriented sources, is only implemented when the +/// underlying [`IoCallbackUnsafe`] is byte-oriented, too. Specifically, this means that it is implemented +/// for `Context`, i.e. TLS connections but not for DTLS connections. +impl> Read for Context { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + match self.recv(buf) { + Err(Error::SslPeerCloseNotify) => Ok(0), + Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()), + Err(e) => Err(crate::private::error_to_io_error(e)), + Ok(i) => Ok(i), + } + } +} + +#[cfg(feature = "std")] +/// Implements [`std::io::Write`] whenever T implements `Write`, too. This ensures that +/// `Write`, which is designated for byte-oriented sinks, is only implemented when the +/// underlying [`IoCallbackUnsafe`] is byte-oriented, too. Specifically, this means that it is implemented +/// for `Context`, i.e. TLS connections but not for DTLS connections. +impl> Write for Context { + fn write(&mut self, buf: &[u8]) -> IoResult { + match self.send(buf) { + Err(Error::SslPeerCloseNotify) => Ok(0), + Err(Error::SslWantRead) | Err(Error::SslWantWrite) => Err(IoErrorKind::WouldBlock.into()), + Err(e) => Err(crate::private::error_to_io_error(e)), + Ok(i) => Ok(i), + } + } + + fn flush(&mut self) -> IoResult<()> { + Ok(()) + } +} diff --git a/mbedtls/src/ssl/mod.rs b/mbedtls/src/ssl/mod.rs index 1bfc078cf..6db024381 100644 --- a/mbedtls/src/ssl/mod.rs +++ b/mbedtls/src/ssl/mod.rs @@ -10,6 +10,7 @@ pub mod ciphersuites; pub mod config; pub mod context; pub mod cookie; +pub mod io; pub mod ticket; #[doc(inline)] @@ -21,4 +22,6 @@ pub use self::context::Context; #[doc(inline)] pub use self::cookie::CookieContext; #[doc(inline)] +pub use self::io::Io; +#[doc(inline)] pub use self::ticket::TicketContext; diff --git a/mbedtls/tests/client_server.rs b/mbedtls/tests/client_server.rs index 43f613491..c967f1751 100644 --- a/mbedtls/tests/client_server.rs +++ b/mbedtls/tests/client_server.rs @@ -11,65 +11,68 @@ // needed to have common code for `mod support` in unit and integrations tests extern crate mbedtls; +use std::io::{Read, Write}; use std::net::TcpStream; use mbedtls::pk::Pk; use mbedtls::rng::CtrDrbg; use mbedtls::ssl::config::{Endpoint, Preset, Transport}; -use mbedtls::ssl::context::{ConnectedUdpSocket, IoCallback, Timer}; -use mbedtls::ssl::{Config, Context, CookieContext, Version}; +use mbedtls::ssl::context::Timer; +use mbedtls::ssl::io::{ConnectedUdpSocket, IoCallback}; +use mbedtls::ssl::{Config, Context, CookieContext, Io, Version}; use mbedtls::x509::{Certificate, VerifyError}; use mbedtls::Error; use mbedtls::Result as TlsResult; use std::sync::Arc; -use mbedtls_sys::types::raw_types::*; -use mbedtls_sys::types::size_t; - mod support; use support::entropy::entropy_new; use support::keys; -/// Simple type to unify TCP and UDP connections, to support both TLS and DTLS -enum Connection { - Tcp(TcpStream), - Udp(ConnectedUdpSocket), +trait TransportType: Sized { + fn get_transport_type() -> Transport; + + fn recv(ctx: &mut Context, buf: &mut [u8]) -> TlsResult; + fn send(ctx: &mut Context, buf: &[u8]) -> TlsResult; } -impl IoCallback for Connection { - unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int { - let conn = &mut *(user_data as *mut Connection); - match conn { - Connection::Tcp(c) => TcpStream::call_recv(c.data_ptr(), data, len), - Connection::Udp(c) => ConnectedUdpSocket::call_recv(c.data_ptr(), data, len), - } +impl TransportType for TcpStream { + fn get_transport_type() -> Transport { + Transport::Stream } - unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int { - let conn = &mut *(user_data as *mut Connection); - match conn { - Connection::Tcp(c) => TcpStream::call_send(c.data_ptr(), data, len), - Connection::Udp(c) => ConnectedUdpSocket::call_send(c.data_ptr(), data, len), - } + fn recv(ctx: &mut Context, buf: &mut [u8]) -> TlsResult { + ctx.read(buf).map_err(|_| Error::NetRecvFailed) } - fn data_ptr(&mut self) -> *mut c_void { - self as *mut Connection as *mut c_void + fn send(ctx: &mut Context, buf: &[u8]) -> TlsResult { + ctx.write(buf).map_err(|_| Error::NetSendFailed) } } -fn client( - conn: Connection, +impl TransportType for ConnectedUdpSocket { + fn get_transport_type() -> Transport { + Transport::Datagram + } + + fn recv(ctx: &mut Context, buf: &mut [u8]) -> TlsResult { + Io::recv(ctx, buf) + } + + fn send(ctx: &mut Context, buf: &[u8]) -> TlsResult { + Io::send(ctx, buf) + } +} + +fn client + TransportType, T>( + conn: C, min_version: Version, max_version: Version, exp_version: Option, use_psk: bool) -> TlsResult<()> { let entropy = Arc::new(entropy_new()); let rng = Arc::new(CtrDrbg::new(entropy, None)?); - let mut config = match conn { - Connection::Tcp(_) => Config::new(Endpoint::Client, Transport::Stream, Preset::Default), - Connection::Udp(_) => Config::new(Endpoint::Client, Transport::Datagram, Preset::Default), - }; + let mut config = Config::new(Endpoint::Client, C::get_transport_type(), Preset::Default); config.set_rng(rng); config.set_min_version(min_version)?; config.set_max_version(max_version)?; @@ -98,7 +101,7 @@ fn client( let mut ctx = Context::new(Arc::new(config)); // For DTLS, timers are required to support retransmissions - if let Connection::Udp(_) = conn { + if C::get_transport_type() == Transport::Datagram { ctx.set_timer_callback(Box::new(Timer::new())); } @@ -118,15 +121,15 @@ fn client( let ciphersuite = ctx.ciphersuite().unwrap(); let buf = format!("Client2Server {:4x}", ciphersuite); - assert_eq!(ctx.send(buf.as_bytes()).unwrap(), buf.len()); + assert_eq!(::send(&mut ctx, buf.as_bytes()).unwrap(), buf.len()); let mut buf = [0u8; 13 + 4 + 1]; - assert_eq!(ctx.recv(&mut buf).unwrap(), buf.len()); + assert_eq!(::recv(&mut ctx, &mut buf).unwrap(), buf.len()); assert_eq!(&buf, format!("Server2Client {:4x}", ciphersuite).as_bytes()); Ok(()) } -fn server( - conn: Connection, +fn server + TransportType, T>( + conn: C, min_version: Version, max_version: Version, exp_version: Option, @@ -134,16 +137,12 @@ fn server( ) -> TlsResult<()> { let entropy = entropy_new(); let rng = Arc::new(CtrDrbg::new(Arc::new(entropy), None)?); - let mut config = match conn { - Connection::Tcp(_) => Config::new(Endpoint::Server, Transport::Stream, Preset::Default), - Connection::Udp(_) => { - let mut config = Config::new(Endpoint::Server, Transport::Datagram, Preset::Default); - // For DTLS, we need a cookie context to work against DoS attacks - let cookies = CookieContext::new(rng.clone())?; - config.set_dtls_cookies(Arc::new(cookies)); - config - } - }; + let mut config = Config::new(Endpoint::Server, C::get_transport_type(), Preset::Default); + if C::get_transport_type() == Transport::Datagram { + // For DTLS, we need a cookie context to work against DoS attacks + let cookies = CookieContext::new(rng.clone())?; + config.set_dtls_cookies(Arc::new(cookies)); + } config.set_rng(rng); config.set_min_version(min_version)?; config.set_max_version(max_version)?; @@ -156,7 +155,7 @@ fn server( } let mut ctx = Context::new(Arc::new(config)); - let res = if let Connection::Udp(_) = conn { + let res = if C::get_transport_type() == Transport::Datagram { // For DTLS, timers are required to support retransmissions and the DTLS server needs a client // ID to create individual cookies per client ctx.set_timer_callback(Box::new(Timer::new())); @@ -190,9 +189,9 @@ fn server( let ciphersuite = ctx.ciphersuite().unwrap(); let buf = format!("Server2Client {:4x}", ciphersuite); - assert_eq!(ctx.send(buf.as_bytes()).unwrap(), buf.len()); + assert_eq!(::send(&mut ctx, buf.as_bytes()).unwrap(), buf.len()); let mut buf = [0u8; 13 + 1 + 4]; - assert_eq!(ctx.recv(&mut buf).unwrap(), buf.len()); + assert_eq!(::recv(&mut ctx, &mut buf).unwrap(), buf.len()); assert_eq!(&buf, format!("Client2Server {:4x}", ciphersuite).as_bytes()); Ok(()) @@ -206,7 +205,7 @@ mod test { fn client_server_test() { use mbedtls::ssl::Version; use std::net::UdpSocket; - use mbedtls::ssl::context::ConnectedUdpSocket; + use mbedtls::ssl::io::ConnectedUdpSocket; #[derive(Copy,Clone)] struct TestConfig { @@ -248,8 +247,8 @@ mod test { // TLS tests using certificates let (c, s) = crate::support::net::create_tcp_pair().unwrap(); - let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver, false).unwrap()); - let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver, false).unwrap()); + let c = thread::spawn(move || super::client(c, min_c, max_c, exp_ver, false).unwrap()); + let s = thread::spawn(move || super::server(s, min_s, max_s, exp_ver, false).unwrap()); c.join().unwrap(); s.join().unwrap(); @@ -257,8 +256,8 @@ mod test { // TLS tests using PSK let (c, s) = crate::support::net::create_tcp_pair().unwrap(); - let c = thread::spawn(move || super::client(super::Connection::Tcp(c), min_c, max_c, exp_ver, true).unwrap()); - let s = thread::spawn(move || super::server(super::Connection::Tcp(s), min_s, max_s, exp_ver, true).unwrap()); + let c = thread::spawn(move || super::client(c, min_c, max_c, exp_ver, true).unwrap()); + let s = thread::spawn(move || super::server(s, min_s, max_s, exp_ver, true).unwrap()); c.join().unwrap(); s.join().unwrap(); @@ -272,10 +271,10 @@ mod test { let s = UdpSocket::bind("127.0.0.1:12340").expect("could not bind UdpSocket"); let s = ConnectedUdpSocket::connect(s, "127.0.0.1:12341").expect("could not connect UdpSocket"); - let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver, false).unwrap()); + let s = thread::spawn(move || super::server(s, min_s, max_s, exp_ver, false).unwrap()); let c = UdpSocket::bind("127.0.0.1:12341").expect("could not bind UdpSocket"); let c = ConnectedUdpSocket::connect(c, "127.0.0.1:12340").expect("could not connect UdpSocket"); - let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver, false).unwrap()); + let c = thread::spawn(move || super::client(c, min_c, max_c, exp_ver, false).unwrap()); s.join().unwrap(); c.join().unwrap(); @@ -289,10 +288,10 @@ mod test { let s = UdpSocket::bind("127.0.0.1:12340").expect("could not bind UdpSocket"); let s = ConnectedUdpSocket::connect(s, "127.0.0.1:12341").expect("could not connect UdpSocket"); - let s = thread::spawn(move || super::server(super::Connection::Udp(s), min_s, max_s, exp_ver, true).unwrap()); + let s = thread::spawn(move || super::server(s, min_s, max_s, exp_ver, true).unwrap()); let c = UdpSocket::bind("127.0.0.1:12341").expect("could not bind UdpSocket"); let c = ConnectedUdpSocket::connect(c, "127.0.0.1:12340").expect("could not connect UdpSocket"); - let c = thread::spawn(move || super::client(super::Connection::Udp(c), min_c, max_c, exp_ver, true).unwrap()); + let c = thread::spawn(move || super::client(c, min_c, max_c, exp_ver, true).unwrap()); s.join().unwrap(); c.join().unwrap();