Skip to content

Commit

Permalink
Reject overflowing connection with status code 429 (#456)
Browse files Browse the repository at this point in the history
* Reject overflowing connection with status code 429

* fmt

* rename Handshake -> HandshakeMode for clarity; verbose test

* Gracefully shutdown after rejecting to hopefully fix the errors on windows

* HandshakeMode -> HandshakeResponse; tweak pending subscriptions on shutdown test
  • Loading branch information
maciejhirsz authored Sep 14, 2021
1 parent a5162ec commit bf2fff0
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 54 deletions.
25 changes: 18 additions & 7 deletions test-utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,33 @@ impl std::fmt::Debug for WebSocketTestClient {
}
}

#[derive(Debug)]
pub enum WebSocketTestError {
Redirect,
RejectedWithStatusCode(u16),
Soketto(SokettoError),
}

impl From<io::Error> for WebSocketTestError {
fn from(err: io::Error) -> Self {
WebSocketTestError::Soketto(SokettoError::Io(err))
}
}

impl WebSocketTestClient {
pub async fn new(url: SocketAddr) -> Result<Self, SokettoError> {
pub async fn new(url: SocketAddr) -> Result<Self, WebSocketTestError> {
let socket = TcpStream::connect(url).await?;
let mut client = handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "test-client", "/");
match client.handshake().await {
Ok(handshake::ServerResponse::Accepted { .. }) => {
let (tx, rx) = client.into_builder().finish();
Ok(Self { tx, rx })
}
Ok(handshake::ServerResponse::Redirect { .. }) => {
Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Redirection not supported in tests")))
Ok(handshake::ServerResponse::Redirect { .. }) => Err(WebSocketTestError::Redirect),
Ok(handshake::ServerResponse::Rejected { status_code }) => {
Err(WebSocketTestError::RejectedWithStatusCode(status_code))
}
Ok(handshake::ServerResponse::Rejected { .. }) => {
Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Rejected")))
}
Err(err) => Err(err),
Err(err) => Err(WebSocketTestError::Soketto(err)),
}
}

Expand Down
14 changes: 11 additions & 3 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,16 @@ async fn ws_close_pending_subscription_when_server_terminated() {

// no new request should be accepted.
assert!(matches!(sub2, Err(_)));

// consume final message
assert!(matches!(sub.next().await, Ok(Some(_))));
// the already established subscription should also be closed.
assert!(matches!(sub.next().await, Ok(None)));
for _ in 0..2 {
match sub.next().await {
// All good, exit test
Ok(None) => return,
// Try again
_ => continue,
}
}

panic!("subscription keeps sending messages after server shutdown");
}
92 changes: 55 additions & 37 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use crate::types::{
};
use futures_channel::mpsc;
use futures_util::io::{BufReader, BufWriter};
// use futures_util::future::FutureExt;
use futures_util::stream::StreamExt;
use soketto::handshake::{server::Response, Server as SokettoServer};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
Expand Down Expand Up @@ -86,13 +85,17 @@ impl Server {

if connections.count() >= self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while.");
connections.add(Box::pin(handshake(socket, HandshakeResponse::Reject { status_code: 429 })));
continue;
}

let methods = &methods;
let cfg = &self.cfg;

connections.add(Box::pin(handshake(socket, id, methods, cfg, &stop_monitor)));
connections.add(Box::pin(handshake(
socket,
HandshakeResponse::Accept { conn_id: id, methods, cfg, stop_monitor: &stop_monitor },
)));

id = id.wrapping_add(1);
}
Expand Down Expand Up @@ -139,49 +142,64 @@ impl<'a> Future for Incoming<'a> {
}
}

async fn handshake(
socket: tokio::net::TcpStream,
conn_id: ConnectionId,
methods: &Methods,
cfg: &Settings,
stop_monitor: &StopMonitor,
) -> Result<(), Error> {
enum HandshakeResponse<'a> {
Reject { status_code: u16 },
Accept { conn_id: ConnectionId, methods: &'a Methods, cfg: &'a Settings, stop_monitor: &'a StopMonitor },
}

async fn handshake(socket: tokio::net::TcpStream, mode: HandshakeResponse<'_>) -> Result<(), Error> {
// For each incoming background_task we perform a handshake.
let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));

let key = {
let req = server.receive_request().await?;
let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host));
let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin);
match mode {
HandshakeResponse::Reject { status_code } => {
// Forced rejection, don't need to read anything from the socket
let reject = Response::Reject { status_code };
server.send_response(&reject).await?;

host_check.and(origin_check).map(|()| req.key())
};
let (mut sender, _) = server.into_builder().finish();

match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
Err(error) => {
let reject = Response::Reject { status_code: 403 };
server.send_response(&reject).await?;
// Gracefully shut down the connection
sender.close().await?;

return Err(error);
Ok(())
}
}
HandshakeResponse::Accept { conn_id, methods, cfg, stop_monitor } => {
let key = {
let req = server.receive_request().await?;
let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host));
let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin);

let join_result = tokio::spawn(background_task(
server,
conn_id,
methods.clone(),
cfg.max_request_body_size,
stop_monitor.clone(),
))
.await;

match join_result {
Err(_) => Err(Error::Custom("Background task was aborted".into())),
Ok(result) => result,
host_check.and(origin_check).map(|()| req.key())
};

match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
Err(error) => {
let reject = Response::Reject { status_code: 403 };
server.send_response(&reject).await?;

return Err(error);
}
}

let join_result = tokio::spawn(background_task(
server,
conn_id,
methods.clone(),
cfg.max_request_body_size,
stop_monitor.clone(),
))
.await;

match join_result {
Err(_) => Err(Error::Custom("Background task was aborted".into())),
Ok(result) => result,
}
}
}
}

Expand Down
11 changes: 4 additions & 7 deletions ws-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{future::StopHandle, RpcModule, WsServerBuilder};
use anyhow::anyhow;
use futures_util::FutureExt;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient};
use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient, WebSocketTestError};
use jsonrpsee_test_utils::TimeoutFutureExt;
use serde_json::Value as JsonValue;
use std::fmt;
Expand Down Expand Up @@ -203,12 +203,9 @@ async fn can_set_max_connections() {
assert!(conn2.is_ok());
// Third connection is rejected
assert!(conn3.is_err());

let err = match conn3 {
Err(soketto::handshake::Error::Io(err)) => err,
_ => panic!("Invalid error kind; expected std::io::Error"),
};
assert_eq!(err.kind(), std::io::ErrorKind::ConnectionReset);
if !matches!(conn3, Err(WebSocketTestError::RejectedWithStatusCode(429))) {
panic!("Expected RejectedWithStatusCode(429), got: {:#?}", conn3);
}

// Decrement connection count
drop(conn2);
Expand Down

0 comments on commit bf2fff0

Please sign in to comment.