Skip to content

Commit

Permalink
Merge branch 'master' into gijsk/windows
Browse files Browse the repository at this point in the history
  • Loading branch information
kwakkel1000 authored Mar 22, 2023
2 parents 745c2a4 + 74dbd1b commit 55a6a75
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 192 deletions.
4 changes: 2 additions & 2 deletions mbedtls/examples/client_dtls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::ssl::{Config, Context, Io};
use mbedtls::x509::Certificate;
use mbedtls::Result as TlsResult;

Expand All @@ -35,7 +35,7 @@ fn result_main(addr: &str) -> TlsResult<()> {
ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new()));

let sock = UdpSocket::bind("localhost:12345").unwrap();
let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap();
let sock = mbedtls::ssl::io::ConnectedUdpSocket::connect(sock, addr).unwrap();
ctx.establish(sock, None).unwrap();

let mut line = String::new();
Expand Down
1 change: 1 addition & 0 deletions mbedtls/src/ssl/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ define!(

define!(
#[c_ty(c_int)]
#[derive(PartialEq, Eq)]
enum Transport {
/// TLS
Stream = SSL_TRANSPORT_STREAM,
Expand Down
146 changes: 13 additions & 133 deletions mbedtls/src/ssl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,9 @@
use core::result::Result as StdResult;

#[cfg(feature = "std")]
use {
std::io::{Read, Write, Result as IoResult, Error as IoError},
std::sync::Arc,
};
use std::sync::Arc;

use mbedtls_sys::types::raw_types::{c_int, c_uchar, c_void};
use mbedtls_sys::types::size_t;
use mbedtls_sys::types::raw_types::{c_int, c_void};
use mbedtls_sys::*;

#[cfg(not(feature = "std"))]
Expand All @@ -25,94 +21,9 @@ use crate::error::{Error, Result, IntoResult};
use crate::pk::Pk;
use crate::private::UnsafeFrom;
use crate::ssl::config::{Config, Version, AuthMode};
use crate::ssl::io::IoCallbackUnsafe;
use crate::x509::{Certificate, Crl, VerifyError};

pub trait IoCallback {
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int where Self: Sized;
unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int where Self: Sized;
fn data_ptr(&mut self) -> *mut c_void;
}

#[cfg(feature = "std")]
impl<IO: Read + Write> IoCallback for IO {
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut IO)).read(::core::slice::from_raw_parts_mut(data, len)) {
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
}
}

unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut IO)).write(::core::slice::from_raw_parts(data, len)) {
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
}
}

fn data_ptr(&mut self) -> *mut c_void {
self as *mut IO as *mut _
}
}

#[cfg(feature = "std")]
pub struct ConnectedUdpSocket {
socket: std::net::UdpSocket,
}

#[cfg(feature = "std")]
impl ConnectedUdpSocket {
pub fn connect<A: std::net::ToSocketAddrs>(socket: std::net::UdpSocket, addr: A) -> StdResult<Self, (IoError, std::net::UdpSocket)> {
match socket.connect(addr) {
Ok(_) => Ok(ConnectedUdpSocket {
socket,
}),
Err(e) => Err((e, socket)),
}
}
}

#[cfg(feature = "std")]
impl IoCallback for ConnectedUdpSocket {
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) {
Ok(i) => i as c_int,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
}
}

unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) {
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
}
}

fn data_ptr(&mut self) -> *mut c_void {
self as *mut ConnectedUdpSocket as *mut c_void
}
}

pub trait TimerCallback: Send + Sync {
unsafe extern "C" fn set_timer(
p_timer: *mut c_void,
Expand Down Expand Up @@ -261,8 +172,13 @@ impl<T> Context<T> {
}
}

impl<T: IoCallback> Context<T> {
pub fn establish(&mut self, io: T, hostname: Option<&str>) -> Result<()> {
impl<T> Context<T> {
/// Establish a TLS session on the given `io`.
///
/// Upon succesful return, the context can be communicated with using the
/// `std::io::Read` and `std::io::Write` traits if `io` implements those as
/// well, and using the `mbedtls::ssl::io::Io` trait otherwise.
pub fn establish<IoType>(&mut self, io: T, hostname: Option<&str>) -> Result<()> where T: IoCallbackUnsafe<IoType> {
unsafe {
let mut io = Box::new(io);
ssl_session_reset(self.into()).into_result()?;
Expand Down Expand Up @@ -292,7 +208,7 @@ impl<T> Context<T> {
/// Try to complete the handshake procedure to set up a (D)TLS connection
///
/// In general, this should not be called directly. Instead, [`establish`](Context::establish)
/// should be used which properly sets up the [`IoCallback`] and resets any previous sessions.
/// should be used which properly sets up the [`IoCallbackUnsafe`] and resets any previous sessions.
///
/// This should only be used directly if the handshake could not be completed successfully in
/// `establish`, i.e.:
Expand Down Expand Up @@ -483,16 +399,14 @@ impl<T> Context<T> {
pub fn set_client_transport_id_once(&mut self, info: &[u8]) {
self.client_transport_id = Some(info.into());
}
}

impl<T: IoCallback> Context<T> {
pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
pub(super) fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
unsafe {
ssl_read(self.into(), buf.as_mut_ptr(), buf.len()).into_result().map(|r| r as usize)
}
}

pub fn send(&mut self, buf: &[u8]) -> Result<usize> {
pub(super) fn send(&mut self, buf: &[u8]) -> Result<usize> {
unsafe {
ssl_write(self.into(), buf.as_ptr(), buf.len()).into_result().map(|w| w as usize)
}
Expand All @@ -508,40 +422,6 @@ impl<T> Drop for Context<T> {
}
}

#[cfg(feature = "std")]
/// Implements [`std::io::Read`] whenever T implements `Read`, too. This ensures that
/// `Read`, which is designated for byte-oriented sources, is only implemented when the
/// underlying [`IoCallback`] is byte-oriented, too. Specifically, this means that it is implemented
/// for `Context<TcpStream>`, i.e. TLS connections but not for DTLS connections.
impl<T: IoCallback + Read> Read for Context<T> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match self.recv(buf) {
Err(Error::SslPeerCloseNotify) => Ok(0),
Err(e) => Err(crate::private::error_to_io_error(e)),
Ok(i) => Ok(i),
}
}
}

#[cfg(feature = "std")]
/// Implements [`std::io::Write`] whenever T implements `Write`, too. This ensures that
/// `Write`, which is designated for byte-oriented sinks, is only implemented when the
/// underlying [`IoCallback`] is byte-oriented, too. Specifically, this means that it is implemented
/// for `Context<TcpStream>`, i.e. TLS connections but not for DTLS connections.
impl<T: IoCallback + Write> Write for Context<T> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self.send(buf) {
Err(Error::SslPeerCloseNotify) => Ok(0),
Err(e) => Err(crate::private::error_to_io_error(e)),
Ok(i) => Ok(i),
}
}

fn flush(&mut self) -> IoResult<()> {
Ok(())
}
}

//
// Class exists only during SNI callback that is configured from Config.
// SNI Callback must provide input whose lifetime exceeds the SNI closure to avoid memory corruptions.
Expand Down
Loading

0 comments on commit 55a6a75

Please sign in to comment.