Skip to content

Commit

Permalink
feat(api)!: make IoErrors respond with BadRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed May 30, 2024
1 parent ea8b215 commit c471bca
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 1 deletion.
2 changes: 2 additions & 0 deletions api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ trillium-router = { path = "../router" }
trillium-smol = { path = "../smol" }
trillium-testing = { path = "../testing" }
trillium-api = { path = ".", features = ["url"] }
test-harness = "0.2.0"
async-channel = "2.3.1"
1 change: 1 addition & 0 deletions api/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ impl From<&Error> for Status {
Status::UnsupportedMediaType
}
Error::FailureToNegotiateContent => Status::NotAcceptable,
Error::IoError { .. } => Status::BadRequest,
_ => Status::InternalServerError,
}
}
Expand Down
122 changes: 122 additions & 0 deletions api/tests/disconnect-body.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use async_channel::Sender;
use test_harness::test;
use trillium::{async_trait, Conn, Handler, Status};
use trillium_testing::{
futures_lite::AsyncWriteExt, harness, AsyncWrite, ClientConfig, Connector, ObjectSafeConnector,
ServerHandle,
};

struct LastStatus(Sender<Option<Status>>);

#[async_trait]
impl Handler for LastStatus {
async fn run(&self, conn: Conn) -> Conn {
conn
}
async fn before_send(&self, conn: Conn) -> Conn {
self.0.send(conn.status()).await.unwrap();
conn
}
}

#[test(harness)]
async fn disconnect_on_string_body() {
async fn api_handler(conn: &mut Conn, body: String) {
conn.set_body(body);
}
let (sender, receiver) = async_channel::bounded(1);
let handler = (LastStatus(sender), trillium_api::api(api_handler));
let (handle, mut client) = establish_server(handler).await;

client
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 10\r\n\r\nnot ten")
.await
.unwrap();

drop(client);
assert_eq!(Some(Status::BadRequest), receiver.recv().await.unwrap());
handle.stop().await;
}

/// this test exists to confirm that the 400 response tested above is in fact due to the disconnect
#[test(harness)]
async fn normal_string_body() {
async fn api_handler(conn: &mut Conn, body: String) {
conn.set_body(body);
}
let (sender, receiver) = async_channel::bounded(1);
let handler = (LastStatus(sender), trillium_api::api(api_handler));
let (handle, mut client) = establish_server(handler).await;
client
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 10\r\n\r\nexactlyten")
.await
.unwrap();

drop(client);
assert_eq!(Some(Status::Ok), receiver.recv().await.unwrap());
handle.stop().await;
}

#[test(harness)]
async fn disconnect_on_vec_body() {
async fn api_handler(conn: &mut Conn, body: Vec<u8>) {
conn.set_body(body);
}
let (sender, receiver) = async_channel::bounded(1);
let handler = (LastStatus(sender), trillium_api::api(api_handler));
let (handle, mut client) = establish_server(handler).await;

client
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 10\r\n\r\nnot ten")
.await
.unwrap();

drop(client);
assert_eq!(Some(Status::BadRequest), receiver.recv().await.unwrap());
handle.stop().await;
}

/// this test exists to confirm that the 400 response tested above is in fact due to the disconnect
#[test(harness)]
async fn normal_vec_body() {
async fn api_handler(conn: &mut Conn, body: Vec<u8>) {
conn.set_body(body);
}
let (sender, receiver) = async_channel::bounded(1);
let handler = (LastStatus(sender), trillium_api::api(api_handler));
let (handle, mut client) = establish_server(handler).await;
client
.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\nContent-Length: 10\r\n\r\nexactlyten")
.await
.unwrap();

drop(client);
assert_eq!(Some(Status::Ok), receiver.recv().await.unwrap());
handle.stop().await;
}

async fn establish_server(handler: impl Handler) -> (ServerHandle, impl AsyncWrite) {
let _ = env_logger::builder().is_test(true).try_init();

let handle = trillium_testing::config().with_port(0).spawn(handler);
let info = handle.info().await;
let port = info.tcp_socket_addr().map_or_else(
|| {
info.listener_description()
.split(":")
.nth(1)
.unwrap()
.parse()
.unwrap()
},
|x| x.port(),
);

let client = Connector::connect(
&ClientConfig::default().boxed(),
&format!("http://localhost:{port}").parse().unwrap(),
)
.await
.unwrap();
(handle, client)
}
2 changes: 1 addition & 1 deletion testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ mod server_connector;
pub use server_connector::{connector, ServerConnector};

use trillium_server_common::Config;
pub use trillium_server_common::{Connector, ObjectSafeConnector, Server};
pub use trillium_server_common::{Connector, ObjectSafeConnector, Server, ServerHandle};

#[derive(Debug)]
/// A droppable future
Expand Down

0 comments on commit c471bca

Please sign in to comment.