diff --git a/actix-web/CHANGES.md b/actix-web/CHANGES.md index 9ded67e35ee..00ab6c9c83f 100644 --- a/actix-web/CHANGES.md +++ b/actix-web/CHANGES.md @@ -3,7 +3,9 @@ ## Unreleased - 2022-xx-xx ### Added - Add `ContentDisposition::attachment` constructor. [#2867] +- Add `ErrorHandlers::default_handler()` (as well as `default_handler_{server, client}()`) to make registering handlers for groups of response statuses easier. [#2784] +[#2784]: https://github.com/actix/actix-web/pull/2784 [#2867]: https://github.com/actix/actix-web/pull/2867 diff --git a/actix-web/src/middleware/err_handlers.rs b/actix-web/src/middleware/err_handlers.rs index f74220cd2e7..3a4e44a2c0f 100644 --- a/actix-web/src/middleware/err_handlers.rs +++ b/actix-web/src/middleware/err_handlers.rs @@ -30,11 +30,25 @@ pub enum ErrorHandlerResponse { type ErrorHandler = dyn Fn(ServiceResponse) -> Result>; +type DefaultHandler = Option>>; + /// Middleware for registering custom status code based error handlers. /// -/// Register handlers with the `ErrorHandlers::handler()` method to register a custom error handler +/// Register handlers with the [`ErrorHandlers::handler()`] method to register a custom error handler /// for a given status code. Handlers can modify existing responses or create completely new ones. /// +/// To register a default handler, use the [`ErrorHandlers::default_handler()`] method. This +/// handler will be used only if a response has an error status code (400-599) that isn't covered by +/// a more specific handler (set with the [`handler()`][ErrorHandlers::handler] method). See examples +/// below. +/// +/// To register a default for only client errors (400-499) or only server errors (500-599), use the +/// [`ErrorHandlers::default_handler_client()`] and [`ErrorHandlers::default_handler_server()`] +/// methods, respectively. +/// +/// Any response with a status code that isn't covered by a specific handler or a default handler +/// will pass by unchanged by this middleware. +/// /// # Examples /// ``` /// use actix_web::http::{header, StatusCode}; @@ -53,7 +67,70 @@ type ErrorHandler = dyn Fn(ServiceResponse) -> Result(mut res: dev::ServiceResponse) -> Result> { +/// res.response_mut().headers_mut().insert( +/// header::CONTENT_TYPE, +/// header::HeaderValue::from_static("Error"), +/// ); +/// Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) +/// } +/// +/// fn handle_bad_request(mut res: dev::ServiceResponse) -> Result> { +/// res.response_mut().headers_mut().insert( +/// header::CONTENT_TYPE, +/// header::HeaderValue::from_static("Bad Request Error"), +/// ); +/// Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) +/// } +/// +/// // Bad Request errors will hit `handle_bad_request()`, while all other errors will hit +/// // `add_error_header()`. The order in which the methods are called is not meaningful. +/// let app = App::new() +/// .wrap( +/// ErrorHandlers::new() +/// .default_handler(add_error_header) +/// .handler(StatusCode::BAD_REQUEST, handle_bad_request) +/// ) +/// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError))); +/// ``` +/// Alternatively, you can set default handlers for only client or only server errors: +/// +/// ```rust +/// # use actix_web::http::{header, StatusCode}; +/// # use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers}; +/// # use actix_web::{dev, web, App, HttpResponse, Result}; +/// # fn add_error_header(mut res: dev::ServiceResponse) -> Result> { +/// # res.response_mut().headers_mut().insert( +/// # header::CONTENT_TYPE, +/// # header::HeaderValue::from_static("Error"), +/// # ); +/// # Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) +/// # } +/// # fn handle_bad_request(mut res: dev::ServiceResponse) -> Result> { +/// # res.response_mut().headers_mut().insert( +/// # header::CONTENT_TYPE, +/// # header::HeaderValue::from_static("Bad Request Error"), +/// # ); +/// # Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) +/// # } +/// // Bad request errors will hit `handle_bad_request()`, other client errors will hit +/// // `add_error_header()`, and server errors will pass through unchanged +/// let app = App::new() +/// .wrap( +/// ErrorHandlers::new() +/// .default_handler_client(add_error_header) // or .default_handler_server +/// .handler(StatusCode::BAD_REQUEST, handle_bad_request) +/// ) +/// .service(web::resource("/").route(web::get().to(HttpResponse::InternalServerError))); +/// ``` pub struct ErrorHandlers { + default_client: DefaultHandler, + default_server: DefaultHandler, handlers: Handlers, } @@ -62,6 +139,8 @@ type Handlers = Rc>>>; impl Default for ErrorHandlers { fn default() -> Self { ErrorHandlers { + default_client: Default::default(), + default_server: Default::default(), handlers: Default::default(), } } @@ -83,6 +162,66 @@ impl ErrorHandlers { .insert(status, Box::new(handler)); self } + + /// Register a default error handler. + /// + /// Any request with a status code that hasn't been given a specific other handler (by calling + /// [`.handler()`][ErrorHandlers::handler]) will fall back on this. + /// + /// Note that this will overwrite any default handlers previously set by calling + /// [`.default_handler_client()`][ErrorHandlers::default_handler_client] or + /// [`.default_handler_server()`][ErrorHandlers::default_handler_server], but not any set by + /// calling [`.handler()`][ErrorHandlers::handler]. + pub fn default_handler(self, handler: F) -> Self + where + F: Fn(ServiceResponse) -> Result> + 'static, + { + let handler = Rc::new(handler); + Self { + default_server: Some(handler.clone()), + default_client: Some(handler), + ..self + } + } + + /// Register a handler on which to fall back for client error status codes (400-499). + pub fn default_handler_client(self, handler: F) -> Self + where + F: Fn(ServiceResponse) -> Result> + 'static, + { + Self { + default_client: Some(Rc::new(handler)), + ..self + } + } + + /// Register a handler on which to fall back for server error status codes (500-599). + pub fn default_handler_server(self, handler: F) -> Self + where + F: Fn(ServiceResponse) -> Result> + 'static, + { + Self { + default_server: Some(Rc::new(handler)), + ..self + } + } + + /// Selects the most appropriate handler for the given status code. + /// + /// If the `handlers` map has an entry for that status code, that handler is returned. + /// Otherwise, fall back on the appropriate default handler. + fn get_handler<'a>( + status: &StatusCode, + default_client: Option<&'a ErrorHandler>, + default_server: Option<&'a ErrorHandler>, + handlers: &'a Handlers, + ) -> Option<&'a ErrorHandler> { + handlers + .get(status) + .map(|h| h.as_ref()) + .or_else(|| status.is_client_error().then(|| default_client).flatten()) + .or_else(|| status.is_server_error().then(|| default_server).flatten()) + } } impl Transform for ErrorHandlers @@ -99,13 +238,24 @@ where fn new_transform(&self, service: S) -> Self::Future { let handlers = self.handlers.clone(); - Box::pin(async move { Ok(ErrorHandlersMiddleware { service, handlers }) }) + let default_client = self.default_client.clone(); + let default_server = self.default_server.clone(); + Box::pin(async move { + Ok(ErrorHandlersMiddleware { + service, + default_client, + default_server, + handlers, + }) + }) } } #[doc(hidden)] pub struct ErrorHandlersMiddleware { service: S, + default_client: DefaultHandler, + default_server: DefaultHandler, handlers: Handlers, } @@ -123,8 +273,15 @@ where fn call(&self, req: ServiceRequest) -> Self::Future { let handlers = self.handlers.clone(); + let default_client = self.default_client.clone(); + let default_server = self.default_server.clone(); let fut = self.service.call(req); - ErrorHandlersFuture::ServiceFuture { fut, handlers } + ErrorHandlersFuture::ServiceFuture { + fut, + default_client, + default_server, + handlers, + } } } @@ -137,6 +294,8 @@ pin_project! { ServiceFuture { #[pin] fut: Fut, + default_client: DefaultHandler, + default_server: DefaultHandler, handlers: Handlers, }, ErrorHandlerFuture { @@ -153,10 +312,22 @@ where fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.as_mut().project() { - ErrorHandlersProj::ServiceFuture { fut, handlers } => { + ErrorHandlersProj::ServiceFuture { + fut, + default_client, + default_server, + handlers, + } => { let res = ready!(fut.poll(cx))?; - - match handlers.get(&res.status()) { + let status = res.status(); + + let handler = ErrorHandlers::get_handler( + &status, + default_client.as_mut().map(|f| Rc::as_ref(f)), + default_server.as_mut().map(|f| Rc::as_ref(f)), + handlers, + ); + match handler { Some(handler) => match handler(res)? { ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)), ErrorHandlerResponse::Future(fut) => { @@ -166,7 +337,6 @@ where self.poll(cx) } }, - None => Poll::Ready(Ok(res.map_into_left_body())), } } @@ -298,4 +468,117 @@ mod tests { "error in error handler" ); } + + #[actix_rt::test] + async fn default_error_handler() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler(mut res: ServiceResponse) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + } + + let make_mw = |status| async move { + ErrorHandlers::new() + .default_handler(error_handler) + .new_transform(test::status_service(status).into_service()) + .await + .unwrap() + }; + let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await; + let mw_client = make_mw(StatusCode::BAD_REQUEST).await; + + let resp = + test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + + let resp = + test::call_service(&mw_server, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + } + + #[actix_rt::test] + async fn default_handlers_separate_client_server() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler_client( + mut res: ServiceResponse, + ) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + } + + #[allow(clippy::unnecessary_wraps)] + fn error_handler_server( + mut res: ServiceResponse, + ) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0002")); + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + } + + let make_mw = |status| async move { + ErrorHandlers::new() + .default_handler_server(error_handler_server) + .default_handler_client(error_handler_client) + .new_transform(test::status_service(status).into_service()) + .await + .unwrap() + }; + let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await; + let mw_client = make_mw(StatusCode::BAD_REQUEST).await; + + let resp = + test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + + let resp = + test::call_service(&mw_server, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002"); + } + + #[actix_rt::test] + async fn default_handlers_specialization() { + #[allow(clippy::unnecessary_wraps)] + fn error_handler_client( + mut res: ServiceResponse, + ) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0001")); + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + } + + #[allow(clippy::unnecessary_wraps)] + fn error_handler_specific( + mut res: ServiceResponse, + ) -> Result> { + res.response_mut() + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("0003")); + Ok(ErrorHandlerResponse::Response(res.map_into_left_body())) + } + + let make_mw = |status| async move { + ErrorHandlers::new() + .default_handler_client(error_handler_client) + .handler(StatusCode::UNPROCESSABLE_ENTITY, error_handler_specific) + .new_transform(test::status_service(status).into_service()) + .await + .unwrap() + }; + let mw_client = make_mw(StatusCode::BAD_REQUEST).await; + let mw_specific = make_mw(StatusCode::UNPROCESSABLE_ENTITY).await; + + let resp = + test::call_service(&mw_client, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001"); + + let resp = + test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await; + assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003"); + } }