Skip to content

Commit

Permalink
Merge branch 'master' for take_io()
Browse files Browse the repository at this point in the history
  • Loading branch information
djc committed Jun 6, 2023
2 parents 6e9b052 + fcbae20 commit 4476f8b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
45 changes: 45 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,51 @@ where
io: Some(io),
}
}

/// Takes back the client connection. Will return `None` if called more than once or if the
/// connection has been accepted.
///
/// # Example
///
/// ```no_run
/// # fn choose_server_config(
/// # _: rustls::server::ClientHello,
/// # ) -> std::sync::Arc<rustls::ServerConfig> {
/// # unimplemented!();
/// # }
/// # #[allow(unused_variables)]
/// # async fn listen() {
/// use tokio::io::AsyncWriteExt;
/// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
/// let (stream, _) = listener.accept().await.unwrap();
///
/// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
/// futures_util::pin_mut!(acceptor);
///
/// match acceptor.as_mut().await {
/// Ok(start) => {
/// let clientHello = start.client_hello();
/// let config = choose_server_config(clientHello);
/// let stream = start.into_stream(config).await.unwrap();
/// // Proceed with handling the ServerConnection...
/// }
/// Err(err) => {
/// if let Some(mut stream) = acceptor.take_io() {
/// stream
/// .write_all(
/// format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
/// .as_bytes()
/// )
/// .await
/// .unwrap();
/// }
/// }
/// }
/// # }
/// ```
pub fn take_io(&mut self) -> Option<IO> {
self.io.take()
}
}

impl<IO> Future for LazyConfigAcceptor<IO>
Expand Down
36 changes: 36 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::time::Duration;
use std::{io, thread};
use tokio::io::{copy, split, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::oneshot;
use tokio::{runtime, time};
use tokio_rustls::{LazyConfigAcceptor, TlsAcceptor, TlsConnector};

Expand Down Expand Up @@ -215,5 +216,40 @@ async fn lazy_config_acceptor_eof() {
}
}

#[tokio::test]
async fn lazy_config_acceptor_take_io() -> Result<(), rustls::Error> {
let (mut cstream, sstream) = tokio::io::duplex(1200);

let (tx, rx) = oneshot::channel();

tokio::spawn(async move {
cstream.write_all(b"hello, world!").await.unwrap();

let mut buf = Vec::new();
cstream.read_to_end(&mut buf).await.unwrap();
tx.send(buf).unwrap();
});

let acceptor = LazyConfigAcceptor::new(rustls::server::Acceptor::default(), sstream);
futures_util::pin_mut!(acceptor);
if (acceptor.as_mut().await).is_ok() {
panic!("Expected Err(err)");
}

let server_msg = b"message from server";

let some_io = acceptor.take_io();
assert!(some_io.is_some(), "Expected Some(io)");
some_io.unwrap().write_all(server_msg).await.unwrap();

assert_eq!(rx.await.unwrap(), server_msg);

assert!(
acceptor.take_io().is_none(),
"Should not be able to take twice"
);
Ok(())
}

// Include `utils` module
include!("utils.rs");

0 comments on commit 4476f8b

Please sign in to comment.