Skip to content

Commit

Permalink
Server: Allows to listen while not blocking the thread
Browse files Browse the repository at this point in the history
The server was already running in background threads, let's not block the main thread
  • Loading branch information
Tpt committed Nov 18, 2023
1 parent 7dbe179 commit 87b5e81
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 74 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ It is still a work in progress. Use at your own risks behind a reverse proxy!

Example:
```rust no_run
use std::net::{Ipv4Addr, Ipv6Addr};
use oxhttp::Server;
use oxhttp::model::{Response, Status};
use std::time::Duration;
Expand All @@ -50,12 +51,14 @@ let mut server = Server::new(|request| {
Response::builder(Status::NOT_FOUND).build()
}
});
// We bind the server to localhost on both IPv4 and v6
server = server.bind((Ipv4Addr::LOCALHOST, 8080)).bind((Ipv6Addr::LOCALHOST, 8080));
// Raise a timeout error if the client does not respond after 10s.
server = server.with_global_timeout(Duration::from_secs(10));
// Limits the max number of concurrent connections to 128.
server = server.with_max_concurrent_connections(128);
// Listen to localhost:8080
server.listen(("localhost", 8080)).unwrap();
// We spawn the server and block on it
server.spawn().unwrap().join().unwrap();
```

## License
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ mod utils;
#[cfg(feature = "client")]
pub use client::Client;
#[cfg(feature = "server")]
pub use server::Server;
pub use server::{ListeningServer, Server};
153 changes: 82 additions & 71 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use crate::model::{
};
use std::fmt;
use std::io::{copy, sink, BufReader, BufWriter, Error, ErrorKind, Result, Write};
use std::net::{TcpListener, TcpStream, ToSocketAddrs};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, Builder};
use std::thread::{Builder, JoinHandle};
use std::time::Duration;

/// An HTTP server.
Expand All @@ -16,6 +16,7 @@ use std::time::Duration;
/// To avoid crashes it is possible to set an upper bound to the number of concurrent connections using the [`Server::with_max_concurrent_connections`] function.
///
/// ```no_run
/// use std::net::{Ipv4Addr, Ipv6Addr};
/// use oxhttp::Server;
/// use oxhttp::model::{Response, Status};
/// use std::time::Duration;
Expand All @@ -28,17 +29,20 @@ use std::time::Duration;
/// Response::builder(Status::NOT_FOUND).build()
/// }
/// });
/// // We bind the server to localhost on both IPv4 and v6
/// server = server.bind((Ipv4Addr::LOCALHOST, 8080)).bind((Ipv6Addr::LOCALHOST, 8080));
/// // Raise a timeout error if the client does not respond after 10s.
/// server = server.with_global_timeout(Duration::from_secs(10));
/// // Limits the number of concurrent connections to 128.
/// server = server.with_max_concurrent_connections(128);
/// // Listen to localhost:8080
/// server.listen(("localhost", 8080))?;
/// // We spawn the server and block on it
/// server.spawn()?.join()?;
/// # Result::<_,Box<dyn std::error::Error>>::Ok(())
/// ```
#[allow(missing_copy_implementations)]
pub struct Server {
on_request: Box<dyn Fn(&mut Request) -> Response + Send + Sync + 'static>,
on_request: Arc<dyn Fn(&mut Request) -> Response + Send + Sync + 'static>,
socket_addrs: Vec<SocketAddr>,
timeout: Option<Duration>,
server: Option<HeaderValue>,
max_num_thread: Option<usize>,
Expand All @@ -49,13 +53,23 @@ impl Server {
#[inline]
pub fn new(on_request: impl Fn(&mut Request) -> Response + Send + Sync + 'static) -> Self {
Self {
on_request: Box::new(on_request),
on_request: Arc::new(on_request),
socket_addrs: Vec::new(),
timeout: None,
server: None,
max_num_thread: None,
}
}

/// Ask the server to listen to a given socket when spawned.
pub fn bind(mut self, addr: impl Into<SocketAddr>) -> Self {
let addr = addr.into();
if !self.socket_addrs.contains(&addr) {
self.socket_addrs.push(addr);
}
self
}

/// Sets the default value for the [`Server`](https://httpwg.org/http-core/draft-ietf-httpbis-semantics-latest.html#field.server) header.
#[inline]
pub fn with_server_name(
Expand All @@ -80,18 +94,22 @@ impl Server {
self
}

/// Runs the server by listening to `address`.
pub fn listen(&self, address: impl ToSocketAddrs) -> Result<()> {
thread::scope(|scope| {
let timeout = self.timeout;
let thread_limit = self.max_num_thread.map(Semaphore::new);
let listener_threads = open_tcp(address)?
/// Spawns the server by listening to the given addresses.
///
/// Note that this is not blocking.
/// To wait for the server to terminate indefinitely, call [`join`](ListeningServer::join) on the result.
pub fn spawn(self) -> Result<ListeningServer> {
let timeout = self.timeout;
let thread_limit = self.max_num_thread.map(Semaphore::new);
let listener_threads = self.socket_addrs
.into_iter()
.map(|listener| {
let listener_addr = listener.local_addr()?;
.map(|listener_addr| {
let listener = TcpListener::bind(listener_addr)?;
let thread_name = format!("{}: listener thread of OxHTTP", listener_addr);
let thread_limit = thread_limit.clone();
Builder::new().name(thread_name).spawn_scoped(scope, move || {
let on_request = Arc::clone(&self.on_request);
let server = self.server.clone();
Builder::new().name(thread_name).spawn(move || {
for stream in listener.incoming() {
match stream {
Ok(stream) => {
Expand All @@ -104,10 +122,12 @@ impl Server {
};
let thread_name = format!("{}: responding thread of OxHTTP", peer_addr);
let thread_guard = thread_limit.as_ref().map(|s| s.lock());
if let Err(error) = Builder::new().name(thread_name).spawn_scoped(scope,
let on_request = Arc::clone(&on_request);
let server = server.clone();
if let Err(error) = Builder::new().name(thread_name).spawn(
move || {
if let Err(error) =
accept_request(stream, &*self.on_request, timeout, &self.server)
accept_request(stream, &*on_request, timeout, &server)
{
eprintln!(
"OxHTTP TCP error when writing response to {peer_addr}: {error}"
Expand All @@ -127,41 +147,33 @@ impl Server {
})
})
.collect::<Result<Vec<_>>>()?;
for thread in listener_threads {
thread.join().map_err(|e| {
Error::new(
ErrorKind::Other,
if let Ok(e) = e.downcast::<&dyn fmt::Display>() {
format!("The server thread panicked with error: {e}")
} else {
"The server thread panicked with an unknown error".into()
},
)
})?;
}
Ok(())
Ok(ListeningServer {
threads: listener_threads,
})
}
}

fn open_tcp(address: impl ToSocketAddrs) -> Result<Vec<TcpListener>> {
let mut listeners = Vec::new();
let mut last_error = None;
for address in address.to_socket_addrs()? {
match TcpListener::bind(address) {
Ok(listener) => listeners.push(listener),
Err(e) => last_error = Some(e),
/// Handle to a running server created by [`Server::spawn`].
pub struct ListeningServer {
threads: Vec<JoinHandle<()>>,
}

impl ListeningServer {
/// Join the server threads and wait for them indefinitely except in case of crash.
pub fn join(self) -> Result<()> {
for thread in self.threads {
thread.join().map_err(|e| {
Error::new(
ErrorKind::Other,
if let Ok(e) = e.downcast::<&dyn fmt::Display>() {
format!("The server thread panicked with error: {e}")
} else {
"The server thread panicked with an unknown error".into()
},
)
})?;
}
}
if listeners.is_empty() {
Err(last_error.unwrap_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
"could not resolve to any addresses",
)
}))
} else {
Ok(listeners)
Ok(())
}
}

Expand Down Expand Up @@ -328,7 +340,8 @@ mod tests {
use super::*;
use crate::model::Status;
use std::io::Read;
use std::thread::{sleep, spawn};
use std::net::{Ipv4Addr, Ipv6Addr};
use std::thread::sleep;

#[test]
fn test_regular_http_operations() -> Result<()> {
Expand Down Expand Up @@ -365,20 +378,19 @@ mod tests {
requests: impl IntoIterator<Item = &'static str>,
responses: impl IntoIterator<Item = &'static str>,
) -> Result<()> {
spawn(move || {
Server::new(|request| {
if request.url().path() == "/" {
Response::builder(Status::OK).with_body("home")
} else {
Response::builder(Status::NOT_FOUND).build()
}
})
.with_server_name("OxHTTP/1.0")
.unwrap()
.with_global_timeout(Duration::from_secs(1))
.listen(("localhost", server_port))
.unwrap();
});
Server::new(|request| {
if request.url().path() == "/" {
Response::builder(Status::OK).with_body("home")
} else {
Response::builder(Status::NOT_FOUND).build()
}
})
.bind((Ipv4Addr::LOCALHOST, server_port))
.bind((Ipv6Addr::LOCALHOST, server_port))
.with_server_name("OxHTTP/1.0")
.unwrap()
.with_global_timeout(Duration::from_secs(1))
.spawn()?;
sleep(Duration::from_millis(100)); // Makes sure the server is up
let mut stream = TcpStream::connect((request_host, server_port))?;
for (request, response) in requests.into_iter().zip(responses) {
Expand All @@ -395,15 +407,14 @@ mod tests {
let server_port = 9996;
let request = b"GET / HTTP/1.1\nhost: localhost:9999\n\n";
let response = b"HTTP/1.1 200 OK\r\nserver: OxHTTP/1.0\r\ncontent-length: 4\r\n\r\nhome";
spawn(move || {
Server::new(|_| Response::builder(Status::OK).with_body("home"))
.with_server_name("OxHTTP/1.0")
.unwrap()
.with_global_timeout(Duration::from_secs(1))
.with_max_concurrent_connections(2)
.listen(("localhost", server_port))
.unwrap();
});
Server::new(|_| Response::builder(Status::OK).with_body("home"))
.bind((Ipv4Addr::LOCALHOST, server_port))
.bind((Ipv6Addr::LOCALHOST, server_port))
.with_server_name("OxHTTP/1.0")
.unwrap()
.with_global_timeout(Duration::from_secs(1))
.with_max_concurrent_connections(2)
.spawn()?;
sleep(Duration::from_millis(100)); // Makes sure the server is up
let streams = (0..128)
.map(|_| {
Expand Down

0 comments on commit 87b5e81

Please sign in to comment.