From b11d407dc963f79a73cee2439c7062e30ca83913 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 18:35:59 -0600 Subject: [PATCH 01/20] Apply basic design to the core (and http) library --- core/http/Cargo.toml | 3 +- core/http/src/status.rs | 3 + core/lib/Cargo.toml | 7 +- core/lib/src/catcher/catcher.rs | 19 +- core/lib/src/catcher/handler.rs | 13 +- core/lib/src/catcher/mod.rs | 2 + core/lib/src/catcher/types.rs | 334 ++++++++++++++++++++ core/lib/src/data/capped.rs | 2 +- core/lib/src/erased.rs | 35 +- core/lib/src/fairing/fairings.rs | 14 + core/lib/src/fairing/info_kind.rs | 10 +- core/lib/src/fairing/mod.rs | 19 ++ core/lib/src/fs/named_file.rs | 2 +- core/lib/src/fs/server.rs | 5 +- core/lib/src/lifecycle.rs | 135 +++++--- core/lib/src/local/asynchronous/request.rs | 8 +- core/lib/src/local/asynchronous/response.rs | 11 +- core/lib/src/outcome.rs | 7 +- core/lib/src/response/content.rs | 4 +- core/lib/src/response/debug.rs | 8 +- core/lib/src/response/flash.rs | 2 +- core/lib/src/response/mod.rs | 4 +- core/lib/src/response/redirect.rs | 4 +- core/lib/src/response/responder.rs | 43 +-- core/lib/src/response/status.rs | 17 +- core/lib/src/response/stream/bytes.rs | 2 +- core/lib/src/response/stream/reader.rs | 2 +- core/lib/src/response/stream/sse.rs | 2 +- core/lib/src/response/stream/text.rs | 2 +- core/lib/src/route/handler.rs | 12 +- core/lib/src/router/matcher.rs | 5 +- core/lib/src/router/router.rs | 24 +- core/lib/src/server.rs | 6 +- core/lib/src/trace/traceable.rs | 4 +- 34 files changed, 630 insertions(+), 140 deletions(-) create mode 100644 core/lib/src/catcher/types.rs diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index ff62a0ae87..b2beca4adc 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -21,7 +21,7 @@ workspace = true [features] default = [] serde = ["dep:serde", "uncased/with-serde-alloc"] -uuid = ["dep:uuid"] +uuid = ["dep:uuid", "transient/uuid"] [dependencies] tinyvec = { version = "1.6", features = ["std", "rustc_1_57"] } @@ -36,6 +36,7 @@ memchr = "2" stable-pattern = "0.1" cookie = { version = "0.18", features = ["percent-encode"] } state = "0.6" +transient = "0.4.1" [dependencies.serde] version = "1.0" diff --git a/core/http/src/status.rs b/core/http/src/status.rs index 1aa882f438..41d2d90510 100644 --- a/core/http/src/status.rs +++ b/core/http/src/status.rs @@ -1,4 +1,5 @@ use std::fmt; +use transient::Static; /// Enumeration of HTTP status classes. #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -118,6 +119,8 @@ pub struct Status { pub code: u16, } +impl Static for Status {} + impl Default for Status { fn default() -> Self { Status::Ok diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 0ce5f3a854..848d53f6e0 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -27,9 +27,9 @@ default = ["http2", "tokio-macros", "trace"] http2 = ["hyper/http2", "hyper-util/http2"] http3-preview = ["s2n-quic", "s2n-quic-h3", "tls"] secrets = ["cookie/private", "cookie/key-expansion"] -json = ["serde_json"] -msgpack = ["rmp-serde"] -uuid = ["uuid_", "rocket_http/uuid"] +json = ["serde_json", "transient/serde_json"] +msgpack = ["rmp-serde", "transient/rmp-serde"] +uuid = ["uuid_", "rocket_http/uuid", "transient/uuid"] tls = ["rustls", "tokio-rustls", "rustls-pemfile"] mtls = ["tls", "x509-parser"] tokio-macros = ["tokio/macros"] @@ -74,6 +74,7 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] } cookie = { version = "0.18", features = ["percent-encode"] } futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" +transient = { version = "0.4.1", features = ["either"] } # tracing tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] } diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index 2aa1402ada..b030285ecb 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -1,6 +1,8 @@ use std::fmt; use std::io::Cursor; +use transient::TypeId; + use crate::http::uri::Path; use crate::http::ext::IntoOwned; use crate::response::Response; @@ -8,6 +10,8 @@ use crate::request::Request; use crate::http::{Status, ContentType, uri}; use crate::catcher::{Handler, BoxFuture}; +use super::TypedError; + /// An error catching route. /// /// Catchers are routes that run when errors are produced by the application. @@ -128,6 +132,9 @@ pub struct Catcher { /// This is -(number of nonempty segments in base). pub(crate) rank: isize, + /// TypeId to match against + pub(crate) type_id: Option, + /// The catcher's file, line, and column location. pub(crate) location: Option<(&'static str, u32, u32)>, } @@ -188,6 +195,7 @@ impl Catcher { base: uri::Origin::root().clone(), handler: Box::new(handler), rank: rank(uri::Origin::root().path()), + type_id: None, code, location: None, } @@ -313,8 +321,8 @@ impl Catcher { impl Default for Catcher { fn default() -> Self { - fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> { - Box::pin(async move { Ok(default_handler(s, req)) }) + fn handler<'r>(status: Status, e: &'r dyn TypedError<'r>, req: &'r Request<'_>) -> BoxFuture<'r> { + Box::pin(async move { Ok(default_handler(status, e, req)) }) } let mut catcher = Catcher::new(None, handler); @@ -331,7 +339,9 @@ pub struct StaticInfo { /// The catcher's status code. pub code: Option, /// The catcher's handler, i.e, the annotated function. - pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>, + pub handler: for<'r> fn(Status, &'r dyn TypedError<'r>, &'r Request<'_>) -> BoxFuture<'r>, + /// TypeId to match against + pub type_id: Option, /// The file, line, and column where the catcher was defined. pub location: (&'static str, u32, u32), } @@ -343,6 +353,7 @@ impl From for Catcher { let mut catcher = Catcher::new(info.code, info.handler); catcher.name = Some(info.name.into()); catcher.location = Some(info.location); + catcher.type_id = info.type_id; catcher } } @@ -354,6 +365,7 @@ impl fmt::Debug for Catcher { .field("base", &self.base) .field("code", &self.code) .field("rank", &self.rank) + .field("type_id", &self.type_id.as_ref().map(|_| "TY")) .finish() } } @@ -418,6 +430,7 @@ macro_rules! default_handler_fn { pub(crate) fn default_handler<'r>( status: Status, + _error: &'r dyn TypedError<'r>, req: &'r Request<'_> ) -> Response<'r> { let preferred = req.accept().map(|a| a.preferred()); diff --git a/core/lib/src/catcher/handler.rs b/core/lib/src/catcher/handler.rs index f33ceba0e3..ea64fff500 100644 --- a/core/lib/src/catcher/handler.rs +++ b/core/lib/src/catcher/handler.rs @@ -1,5 +1,5 @@ -use crate::{Request, Response}; -use crate::http::Status; +use crate::{Request, Response, http::Status}; +use super::TypedError; /// Type alias for the return type of a [`Catcher`](crate::Catcher)'s /// [`Handler::handle()`]. @@ -97,16 +97,17 @@ pub trait Handler: Cloneable + Send + Sync + 'static { /// Nevertheless, failure is allowed, both for convenience and necessity. If /// an error handler fails, Rocket's default `500` catcher is invoked. If it /// succeeds, the returned `Response` is used to respond to the client. - async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> Result<'r>; + async fn handle<'r>(&self, status: Status, error: &'r dyn TypedError<'r>, req: &'r Request<'_>) -> Result<'r>; } // We write this manually to avoid double-boxing. impl Handler for F - where for<'x> F: Fn(Status, &'x Request<'_>) -> BoxFuture<'x>, + where for<'x> F: Fn(Status, &'x dyn TypedError<'x>, &'x Request<'_>) -> BoxFuture<'x>, { fn handle<'r, 'life0, 'life1, 'async_trait>( &'life0 self, status: Status, + error: &'r dyn TypedError<'r>, req: &'r Request<'life1>, ) -> BoxFuture<'r> where 'r: 'async_trait, @@ -114,13 +115,13 @@ impl Handler for F 'life1: 'async_trait, Self: 'async_trait, { - self(status, req) + self(status, error, req) } } // Used in tests! Do not use, please. #[doc(hidden)] -pub fn dummy_handler<'r>(_: Status, _: &'r Request<'_>) -> BoxFuture<'r> { +pub fn dummy_handler<'r>(_: Status, _: &'r dyn TypedError<'r>, _: &'r Request<'_>) -> BoxFuture<'r> { Box::pin(async move { Ok(Response::new()) }) } diff --git a/core/lib/src/catcher/mod.rs b/core/lib/src/catcher/mod.rs index 4f5fefa19d..d9bbb48d48 100644 --- a/core/lib/src/catcher/mod.rs +++ b/core/lib/src/catcher/mod.rs @@ -2,6 +2,8 @@ mod catcher; mod handler; +mod types; pub use catcher::*; pub use handler::*; +pub use types::*; diff --git a/core/lib/src/catcher/types.rs b/core/lib/src/catcher/types.rs new file mode 100644 index 0000000000..2eae56a97e --- /dev/null +++ b/core/lib/src/catcher/types.rs @@ -0,0 +1,334 @@ +use either::Either; +use transient::{Any, CanRecoverFrom, Downcast, Transience}; +use crate::{http::Status, response::status::Custom, Request, Response}; +#[doc(inline)] +pub use transient::{Static, Transient, TypeId, Inv, CanTranscendTo}; + +/// Polyfill for trait upcasting to [`Any`] +pub trait AsAny: Any + Sealed { + /// The actual upcast + fn as_any(&self) -> &dyn Any; + /// convience typeid of the inner typeid + fn trait_obj_typeid(&self) -> TypeId; +} + +use sealed::Sealed; +mod sealed { + use transient::{Any, Transience, Transient, TypeId}; + + use super::AsAny; + + pub trait Sealed {} + impl<'r, Tr: Transience, T: Any> Sealed for T { } + impl<'r, Tr: Transience, T: Any + Transient> AsAny for T { + fn as_any(&self) -> &dyn Any { + self + } + fn trait_obj_typeid(&self) -> transient::TypeId { + TypeId::of::() + } + } +} + +/// This is the core of typed catchers. If an error type (returned by +/// FromParam, FromRequest, FromForm, FromData, or Responder) implements +/// this trait, it can be caught by a typed catcher. (TODO) This trait +/// can be derived. +pub trait TypedError<'r>: AsAny> + Send + Sync + 'r { + /// Generates a default response for this type (or forwards to a default catcher) + #[allow(unused_variables)] + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + Err(self.status()) + } + + /// A descriptive name of this error type. Defaults to the type name. + fn name(&self) -> &'static str { std::any::type_name::() } + + /// The error that caused this error. Defaults to None. + /// + /// # Warning + /// A typed catcher will not attempt to follow the source of an error + /// more than (TODO: exact number) 5 times. + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { None } + + /// Status code + fn status(&self) -> Status { Status::InternalServerError } +} + +// TODO: this is less useful, since impls should generally use `Status` instead. +impl<'r> TypedError<'r> for () { } + +impl<'r> TypedError<'r> for Status { + fn respond_to(&self, _r: &'r Request<'_>) -> Result, Status> { + Err(*self) + } + + fn name(&self) -> &'static str { + // TODO: Status generally shouldn't be caught + "" + } + + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { + Some(&()) + } + + fn status(&self) -> Status { + *self + } +} + +// TODO: Typed: update transient to make the possible. +// impl<'r, R: TypedError<'r> + Transient> TypedError<'r> for (Status, R) +// where R::Transience: CanTranscendTo> +// { +// fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { +// self.1.respond_to(request) +// } + +// fn name(&self) -> &'static str { +// self.1.name() +// } + +// fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { +// Some(&self.1) +// } + +// fn status(&self) -> Status { +// self.0 +// } +// } + +impl<'r, R: TypedError<'r> + Transient> TypedError<'r> for Custom + where R::Transience: CanTranscendTo> +{ + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + self.1.respond_to(request) + } + + fn name(&self) -> &'static str { + self.1.name() + } + + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { + Some(&self.1) + } + + fn status(&self) -> Status { + self.0 + } +} + +impl<'r> TypedError<'r> for std::convert::Infallible { } + +impl<'r> TypedError<'r> for std::io::Error { + fn status(&self) -> Status { + match self.kind() { + std::io::ErrorKind::NotFound => Status::NotFound, + std::io::ErrorKind::PermissionDenied => Status::Unauthorized, + std::io::ErrorKind::AlreadyExists => Status::Conflict, + std::io::ErrorKind::InvalidInput => Status::BadRequest, + _ => Status::InternalServerError, + } + } +} + +impl<'r> TypedError<'r> for std::num::ParseIntError { + fn status(&self) -> Status { Status::BadRequest } +} + +impl<'r> TypedError<'r> for std::num::ParseFloatError { + fn status(&self) -> Status { Status::BadRequest } +} + +impl<'r> TypedError<'r> for std::string::FromUtf8Error { + fn status(&self) -> Status { Status::BadRequest } +} + +#[cfg(feature = "json")] +impl<'r> TypedError<'r> for serde_json::Error { + fn status(&self) -> Status { Status::BadRequest } +} + +#[cfg(feature = "msgpack")] +impl<'r> TypedError<'r> for rmp_serde::encode::Error { } + +#[cfg(feature = "msgpack")] +impl<'r> TypedError<'r> for rmp_serde::decode::Error { + fn status(&self) -> Status { Status::BadRequest } +} + +// // TODO: This is a hack to make any static type implement Transient +// impl<'r, T: std::fmt::Debug + Send + Sync + 'static> TypedError<'r> for response::Debug { +// fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { +// format!("{:?}", self.0).respond_to(request).responder_error() +// } +// } + +impl<'r, L, R> TypedError<'r> for Either + where L: TypedError<'r> + Transient, + L::Transience: CanTranscendTo>, + R: TypedError<'r> + Transient, + R::Transience: CanTranscendTo>, +{ + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + match self { + Self::Left(v) => v.respond_to(request), + Self::Right(v) => v.respond_to(request), + } + } + + fn name(&self) -> &'static str { + match self { + Self::Left(v) => v.name(), + Self::Right(v) => v.name(), + } + } + + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { + match self { + Self::Left(v) => Some(v), + Self::Right(v) => Some(v), + } + } + + fn status(&self) -> Status { + match self { + Self::Left(v) => v.status(), + Self::Right(v) => v.status(), + } + } +} + +// // TODO: This cannot be used as a bound on an untyped catcher to get any error type. +// // This is mostly an implementation detail (and issue with double boxing) for +// // the responder derive +// // We should just get rid of this. `&dyn TypedError<'_>` impls `FromError` +// #[derive(Transient)] +// pub struct AnyError<'r>(pub Box + 'r>); + +// impl<'r> TypedError<'r> for AnyError<'r> { +// fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { +// Some(self.0.as_ref()) +// } + +// fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { +// self.0.respond_to(request) +// } + +// fn name(&self) -> &'static str { self.0.name() } + +// fn status(&self) -> Status { self.0.status() } +// } + +/// Validates that a type implements `TypedError`. Used by the `#[catch]` attribute to ensure +/// the `TypeError` is first in the diagnostics. +#[doc(hidden)] +pub fn type_id_of<'r, T: TypedError<'r> + Transient + 'r>() -> (TypeId, &'static str) { + (TypeId::of::(), std::any::type_name::()) +} + +/// Downcast an error type to the underlying concrete type. Used by the `#[catch]` attribute. +#[doc(hidden)] +pub fn downcast<'r, T>(v: Option<&'r dyn TypedError<'r>>) -> Option<&'r T> + where T: TypedError<'r> + Transient + 'r, + T::Transience: CanRecoverFrom>, +{ + // if v.is_none() { + // crate::trace::error!("No value to downcast from"); + // } + let v = v?; + // crate::trace::error!("Downcasting error from {}", v.name()); + v.as_any().downcast_ref() +} + +/// Upcasts a value to `Box>`, falling back to a default if it doesn't implement +/// `Error` +#[doc(hidden)] +#[macro_export] +macro_rules! resolve_typed_catcher { + ($T:expr) => ({ + #[allow(unused_imports)] + use $crate::catcher::resolution::{Resolve, DefaultTypeErase, ResolvedTypedError}; + + let inner = Resolve::new($T).cast(); + ResolvedTypedError { + name: inner.as_ref().ok().map(|e| e.name()), + val: inner, + } + }); +} + +pub use resolve_typed_catcher; + +pub mod resolution { + use std::marker::PhantomData; + + use transient::{CanTranscendTo, Transient}; + + use super::*; + + /// The *magic*. + /// + /// `Resolve::item` for `T: Transient` is `::item`. + /// `Resolve::item` for `T: !Transient` is `DefaultTypeErase::item`. + /// + /// This _must_ be used as `Resolve:::item` for resolution to work. This + /// is a fun, static dispatch hack for "specialization" that works because + /// Rust prefers inherent methods over blanket trait impl methods. + pub struct Resolve<'r, T: 'r>(pub T, PhantomData<&'r ()>); + + impl<'r, T: 'r> Resolve<'r, T> { + pub fn new(val: T) -> Self { + Self(val, PhantomData) + } + } + + /// Fallback trait "implementing" `Transient` for all types. This is what + /// Rust will resolve `Resolve::item` to when `T: !Transient`. + pub trait DefaultTypeErase<'r>: Sized { + const SPECIALIZED: bool = false; + + fn cast(self) -> Result>, Self> { Err(self) } + } + + impl<'r, T: 'r> DefaultTypeErase<'r> for Resolve<'r, T> {} + + /// "Specialized" "implementation" of `Transient` for `T: Transient`. This is + /// what Rust will resolve `Resolve::item` to when `T: Transient`. + impl<'r, T: TypedError<'r> + Transient> Resolve<'r, T> + where T::Transience: CanTranscendTo> + { + pub const SPECIALIZED: bool = true; + + pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } + } + + // TODO: These extensions maybe useful, but so far not really + // // Box can be upcast without double boxing? + // impl<'r> Resolve<'r, Box>> { + // pub const SPECIALIZED: bool = true; + + // pub fn cast(self) -> Result>, Self> { Ok(self.0) } + // } + + // Ideally, we should be able to handle this case, but we can't, since we don't own `Either` + // impl<'r, A, B> Resolve<'r, Either> + // where A: TypedError<'r> + Transient, + // A::Transience: CanTranscendTo>, + // B: TypedError<'r> + Transient, + // B::Transience: CanTranscendTo>, + // { + // pub const SPECIALIZED: bool = true; + + // pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } + // } + + /// Wrapper type to hold the return type of `resolve_typed_catcher`. + #[doc(hidden)] + pub struct ResolvedTypedError<'r, T> { + /// The return value from `TypedError::name()`, if Some + pub name: Option<&'static str>, + /// The upcast error, if it supports it + pub val: Result + 'r>, Resolve<'r, T>>, + } +} diff --git a/core/lib/src/data/capped.rs b/core/lib/src/data/capped.rs index 804a42d486..094a10d5b4 100644 --- a/core/lib/src/data/capped.rs +++ b/core/lib/src/data/capped.rs @@ -205,7 +205,7 @@ use crate::response::{self, Responder}; use crate::request::Request; impl<'r, 'o: 'r, T: Responder<'r, 'o>> Responder<'r, 'o> for Capped { - fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'r, 'o> { self.value.respond_to(request) } } diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs index 964f954dda..2cd86bbf8a 100644 --- a/core/lib/src/erased.rs +++ b/core/lib/src/erased.rs @@ -1,4 +1,4 @@ -use std::io; +use std::{fmt, io}; use std::mem::transmute; use std::pin::Pin; use std::sync::Arc; @@ -8,6 +8,7 @@ use futures::future::BoxFuture; use http::request::Parts; use tokio::io::{AsyncRead, ReadBuf}; +use crate::catcher::TypedError; use crate::data::{Data, IoHandler, RawStream}; use crate::{Request, Response, Rocket, Orbit}; @@ -22,6 +23,31 @@ macro_rules! static_assert_covariance { ) } +pub struct ErrorBox { + value: Option>>, +} + +impl ErrorBox { + pub fn new() -> Self { + Self { value: None } + } + pub fn write<'r>(&mut self, error: Box + 'r>) -> &'r dyn TypedError<'r> { + assert!(self.value.is_none()); + self.value = Some(unsafe { transmute(error) }); + let val: &dyn TypedError<'static> = self.value.as_ref().unwrap().as_ref(); + unsafe { transmute(val) } + } +} + +impl fmt::Debug for ErrorBox { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.value { + Some(v) => write!(f, "Some(<{}>)", v.name()), + None => write!(f, "None"), + } + } +} + #[derive(Debug)] pub struct ErasedRequest { // XXX: SAFETY: This (dependent) field must come first due to drop order! @@ -38,6 +64,8 @@ impl Drop for ErasedRequest { pub struct ErasedResponse { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'static>, + // XXX: SAFETY: This (dependent) field must come second due to drop order! + error: ErrorBox, _request: Arc, } @@ -94,6 +122,7 @@ impl ErasedRequest { T, &'r Rocket, &'r Request<'r>, + &'r mut ErrorBox, Data<'r> ) -> BoxFuture<'r, Response<'r>>, ) -> ErasedResponse @@ -102,6 +131,7 @@ impl ErasedRequest { { let mut data: Data<'_> = Data::from(raw_stream); let mut parent = Arc::new(self); + let mut error = ErrorBox { value: None }; let token: T = { let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap(); let rocket: &Rocket = &parent._rocket; @@ -116,11 +146,12 @@ impl ErasedRequest { let parent: &'static ErasedRequest = unsafe { transmute(parent) }; let rocket: &Rocket = &parent._rocket; let request: &Request<'_> = &parent.request; - dispatch(token, rocket, request, data).await + dispatch(token, rocket, request, unsafe { transmute(&mut error) }, data).await }; ErasedResponse { _request: parent, + error, response, } } diff --git a/core/lib/src/fairing/fairings.rs b/core/lib/src/fairing/fairings.rs index 57bd18121e..0b617955fb 100644 --- a/core/lib/src/fairing/fairings.rs +++ b/core/lib/src/fairing/fairings.rs @@ -1,3 +1,4 @@ +use crate::catcher::TypedError; use crate::{Rocket, Request, Response, Data, Build, Orbit}; use crate::fairing::{Fairing, Info, Kind}; @@ -13,6 +14,7 @@ pub struct Fairings { ignite: Vec, liftoff: Vec, request: Vec, + request_filter: Vec, response: Vec, shutdown: Vec, } @@ -42,6 +44,7 @@ impl Fairings { self.ignite.iter() .chain(self.liftoff.iter()) .chain(self.request.iter()) + .chain(self.request_filter.iter()) .chain(self.response.iter()) .chain(self.shutdown.iter()) } @@ -112,6 +115,7 @@ impl Fairings { if this_info.kind.is(Kind::Ignite) { self.ignite.push(index); } if this_info.kind.is(Kind::Liftoff) { self.liftoff.push(index); } if this_info.kind.is(Kind::Request) { self.request.push(index); } + if this_info.kind.is(Kind::RequestFilter) { self.request_filter.push(index); } if this_info.kind.is(Kind::Response) { self.response.push(index); } if this_info.kind.is(Kind::Shutdown) { self.shutdown.push(index); } } @@ -161,6 +165,16 @@ impl Fairings { } } + #[inline(always)] + pub async fn handle_request_filter<'r>(&self, req: &'r Request<'_>, data: &mut Data<'_>) + -> Result<(), Box + 'r>> + { + for fairing in iter!(self.request_filter) { + fairing.on_request_filter(req, data).await?; + } + Ok(()) + } + #[inline(always)] pub async fn handle_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { for fairing in iter!(self.response) { diff --git a/core/lib/src/fairing/info_kind.rs b/core/lib/src/fairing/info_kind.rs index 74ab3a4827..3fd28e6972 100644 --- a/core/lib/src/fairing/info_kind.rs +++ b/core/lib/src/fairing/info_kind.rs @@ -64,15 +64,18 @@ impl Kind { /// `Kind` flag representing a request for a 'request' callback. pub const Request: Kind = Kind(1 << 2); + /// `Kind` flag representing a request for a 'request' callback. + pub const RequestFilter: Kind = Kind(1 << 3); + /// `Kind` flag representing a request for a 'response' callback. - pub const Response: Kind = Kind(1 << 3); + pub const Response: Kind = Kind(1 << 4); /// `Kind` flag representing a request for a 'shutdown' callback. - pub const Shutdown: Kind = Kind(1 << 4); + pub const Shutdown: Kind = Kind(1 << 5); /// `Kind` flag representing a /// [singleton](crate::fairing::Fairing#singletons) fairing. - pub const Singleton: Kind = Kind(1 << 5); + pub const Singleton: Kind = Kind(1 << 6); /// Returns `true` if `self` is a superset of `other`. In other words, /// returns `true` if all of the kinds in `other` are also in `self`. @@ -144,6 +147,7 @@ impl std::fmt::Display for Kind { write("ignite", Kind::Ignite)?; write("liftoff", Kind::Liftoff)?; write("request", Kind::Request)?; + write("request_filter", Kind::RequestFilter)?; write("response", Kind::Response)?; write("shutdown", Kind::Shutdown)?; write("singleton", Kind::Singleton) diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index 28a4e58d54..b52c57f383 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -51,6 +51,7 @@ use std::any::Any; +use crate::catcher::TypedError; use crate::{Rocket, Request, Response, Data, Build, Orbit}; mod fairings; @@ -503,6 +504,24 @@ pub trait Fairing: Send + Sync + AsAny + 'static { /// The default implementation of this method does nothing. async fn on_request(&self, _req: &mut Request<'_>, _data: &mut Data<'_>) {} + /// The request filter callback. + /// + /// See [Fairing Callbacks](#request) for complete semantics. + /// + /// This method is called when a new request is received if `Kind::RequestFilter` + /// is in the `kind` field of the `Info` structure for this fairing. The + /// `&Request` parameter is the incoming request, and the `&Data` + /// parameter is the incoming data in the request. + /// + /// ## Default Implementation + /// + /// The default implementation of this method does nothing. + async fn on_request_filter<'r>(&self, _req: &'r Request<'_>, _data: &mut Data<'_>) + -> Result<(), Box + 'r>> + { + Ok (()) + } + /// The response callback. /// /// See [Fairing Callbacks](#response) for complete semantics. diff --git a/core/lib/src/fs/named_file.rs b/core/lib/src/fs/named_file.rs index d4eed82a92..5aa5198274 100644 --- a/core/lib/src/fs/named_file.rs +++ b/core/lib/src/fs/named_file.rs @@ -152,7 +152,7 @@ impl NamedFile { /// you would like to stream a file with a different Content-Type than that /// implied by its extension, use a [`File`] directly. impl<'r> Responder<'r, 'static> for NamedFile { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { let mut response = self.1.respond_to(req)?; if let Some(ext) = self.0.extension() { if let Some(ct) = ContentType::from_extension(&ext.to_string_lossy()) { diff --git a/core/lib/src/fs/server.rs b/core/lib/src/fs/server.rs index 2f6efe7a26..2d5e060df5 100644 --- a/core/lib/src/fs/server.rs +++ b/core/lib/src/fs/server.rs @@ -3,6 +3,7 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::borrow::Cow; +use crate::catcher::TypedError; use crate::{response, Data, Request, Response}; use crate::outcome::IntoOutcome; use crate::http::{uri::Segments, HeaderMap, Method, ContentType, Status}; @@ -351,7 +352,7 @@ impl Handler for FileServer { None => return Outcome::forward(data, Status::NotFound), }; - outcome.or_forward((data, status)) + outcome.or_forward((data, Box::new(status) as Box + 'r>)) } } @@ -390,7 +391,7 @@ struct NamedFile<'r> { // Do we want to allow the user to rewrite the Content-Type? impl<'r> Responder<'r, 'r> for NamedFile<'r> { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'r> { let mut response = Response::new(); response.set_header_map(self.headers); if !response.headers().contains("Content-Type") { diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index cf76e3be26..62ec4b1c97 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -1,22 +1,27 @@ -use futures::future::{FutureExt, Future}; +use futures::future::{Future, FutureExt}; -use crate::trace::Trace; -use crate::util::Formatter; +use crate::catcher::TypedError; use crate::data::IoHandler; -use crate::http::{Method, Status, Header}; -use crate::outcome::Outcome; +use crate::erased::ErrorBox; use crate::form::Form; -use crate::{route, catcher, Rocket, Orbit, Request, Response, Data}; +use crate::http::{Header, Method, Status}; +use crate::outcome::Outcome; +use crate::trace::Trace; +use crate::util::Formatter; +use crate::{catcher, route, Data, Orbit, Request, Response, Rocket}; // A token returned to force the execution of one method before another. pub(crate) struct RequestToken; async fn catch_handle(name: Option<&str>, run: F) -> Option - where F: FnOnce() -> Fut, Fut: Future, +where + F: FnOnce() -> Fut, + Fut: Future, { macro_rules! panic_info { ($name:expr, $e:expr) => {{ - error!(handler = name.as_ref().map(display), + error!( + handler = name.as_ref().map(display), "handler panicked\n\ This is an application bug.\n\ A panic in Rust must be treated as an exceptional event.\n\ @@ -25,10 +30,11 @@ async fn catch_handle(name: Option<&str>, run: F) -> Option Panics will degrade application performance.\n\ Instead of panicking, return `Option` and/or `Result`.\n\ Values of either type can be returned directly from handlers.\n\ - A panic is treated as an internal server error."); + A panic is treated as an internal server error." + ); $e - }} + }}; } let run = std::panic::AssertUnwindSafe(run); @@ -54,13 +60,14 @@ impl Rocket { pub(crate) async fn preprocess( &self, req: &mut Request<'_>, - data: &mut Data<'_> + data: &mut Data<'_>, ) -> RequestToken { // Check if this is a form and if the form contains the special _method // field which we use to reinterpret the request's method. if req.method() == Method::Post && req.content_type().map_or(false, |v| v.is_form()) { let peek_buffer = data.peek(32).await; - let method = std::str::from_utf8(peek_buffer).ok() + let method = std::str::from_utf8(peek_buffer) + .ok() .and_then(|raw_form| Form::values(raw_form).next()) .filter(|field| field.name == "_method") .and_then(|field| field.value.parse().ok()); @@ -92,28 +99,47 @@ impl Rocket { &'s self, _token: RequestToken, request: &'r Request<'s>, - data: Data<'r>, + error_box: &mut ErrorBox, + mut data: Data<'r>, // io_stream: impl Future> + Send, ) -> Response<'r> { // Remember if the request is `HEAD` for later body stripping. let was_head_request = request.method() == Method::Head; - // Route the request and run the user's handlers. - let mut response = match self.route(request, data).await { - Outcome::Success(response) => response, - Outcome::Forward((data, _)) if request.method() == Method::Head => { - tracing::Span::current().record("autohandled", true); + // Run request filter + let mut response = if let Err(error) = self.fairings.handle_request_filter(request, &mut data).await { + let error = error_box.write(error); + self.dispatch_error(error, request).await + } else { + // Route the request and run the user's handlers. + match self.route(request, data).await { + Outcome::Success(response) => response, + Outcome::Forward((data, _)) if request.method() == Method::Head => { + tracing::Span::current().record("autohandled", true); - // Dispatch the request again with Method `GET`. - request._set_method(Method::Get); - match self.route(request, data).await { - Outcome::Success(response) => response, - Outcome::Error(status) => self.dispatch_error(status, request).await, - Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, + // Dispatch the request again with Method `GET`. + request._set_method(Method::Get); + match self.route(request, data).await { + Outcome::Success(response) => response, + Outcome::Error(error) => { + let error = error_box.write(error); + self.dispatch_error(error, request).await + } + Outcome::Forward((_, error)) => { + let error = error_box.write(error); + self.dispatch_error(error, request).await + } + } + } + Outcome::Forward((_, error)) => { + let error = error_box.write(error); + self.dispatch_error(error, request).await + } + Outcome::Error(error) => { + let error = error_box.write(error); + self.dispatch_error(error, request).await } } - Outcome::Forward((_, status)) => self.dispatch_error(status, request).await, - Outcome::Error(status) => self.dispatch_error(status, request).await, }; // Set the cookies. Note that error responses will only include cookies @@ -201,21 +227,22 @@ impl Rocket { ) -> route::Outcome<'r> { // Go through all matching routes until we fail or succeed or run out of // routes to try, in which case we forward with the last status. - let mut status = Status::NotFound; + let mut status: Box + 'r> = Box::new(Status::NotFound); for route in self.router.route(request) { // Retrieve and set the requests parameters. route.trace_info(); request.set_route(route); let name = route.name.as_deref(); - let outcome = catch_handle(name, || route.handler.handle(request, data)).await - .unwrap_or(Outcome::Error(Status::InternalServerError)); + let outcome = catch_handle(name, || route.handler.handle(request, data)) + .await + .unwrap_or(Outcome::Error(Box::new(Status::InternalServerError))); // Check if the request processing completed (Some) or if the // request needs to be forwarded. If it does, continue the loop outcome.trace_info(); match outcome { - o@Outcome::Success(_) | o@Outcome::Error(_) => return o, + o @ Outcome::Success(_) | o @ Outcome::Error(_) => return o, Outcome::Forward(forwarded) => (data, status) = forwarded, } } @@ -230,28 +257,31 @@ impl Rocket { // // On catcher error, the 500 error catcher is attempted. If _that_ errors, // the (infallible) default 500 error cather is used. - #[tracing::instrument("catching", skip_all, fields(status = status.code, uri = %req.uri()))] + #[tracing::instrument("catching", skip_all, fields(status = error.status().code, uri = %req.uri()))] pub(crate) async fn dispatch_error<'r, 's: 'r>( &'s self, - mut status: Status, - req: &'r Request<'s> + mut error: &'r dyn TypedError<'r>, + req: &'r Request<'s>, ) -> Response<'r> { // We may wish to relax this in the future. req.cookies().reset_delta(); loop { // Dispatch to the `status` catcher. - match self.invoke_catcher(status, req).await { + match self.invoke_catcher(error, req).await { Ok(r) => return r, // If the catcher failed, try `500` catcher, unless this is it. - Err(e) if status.code != 500 => { - warn!(status = e.map(|r| r.code), "catcher failed: trying 500 catcher"); - status = Status::InternalServerError; + Err(e) if error.status().code != 500 => { + warn!( + status = e.map(|r| r.code), + "catcher failed: trying 500 catcher" + ); + error = &Status::InternalServerError; } // The 500 catcher failed. There's no recourse. Use default. Err(e) => { error!(status = e.map(|r| r.code), "500 catcher failed"); - return catcher::default_handler(Status::InternalServerError, req); + return catcher::default_handler(Status::InternalServerError, &(), req); } } } @@ -269,18 +299,35 @@ impl Rocket { /// handler panicked while executing. async fn invoke_catcher<'s, 'r: 's>( &'s self, - status: Status, - req: &'r Request<'s> + error: &'r dyn TypedError<'r>, + req: &'r Request<'s>, ) -> Result, Option> { - if let Some(catcher) = self.router.catch(status, req) { + const MAX_CALLS_TO_SOURCE: usize = 5; + let status = error.status(); + let iter = std::iter::successors(Some(error), |e| e.source()) + .take(MAX_CALLS_TO_SOURCE) + .flat_map(|e| [ + // Catchers with matching status and typeid + self.router.catch(status, Some(e), req), + // Catchers with `default` status and typeid + self.router.catch_any(status, Some(e), req) + ].into_iter().filter_map(|c| c)) + .chain([ + // Catcher with matching status and no typeid + self.router.catch(status, None, req), + // Catcher with `default` status and no typeid + self.router.catch_any(status, None, req) + ].into_iter().filter_map(|c| c)); + // Select lowest rank of (up to) 12 matching catchers. + if let Some(catcher) = iter.min_by_key(|c| c.rank) { catcher.trace_info(); - catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, req)).await + catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, error, req)).await .map(|result| result.map_err(Some)) .unwrap_or_else(|| Err(None)) } else { - info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = status.code, + info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = error.status().code, "no registered catcher: using Rocket default"); - Ok(catcher::default_handler(status, req)) + Ok(catcher::default_handler(status, error, req)) } } } diff --git a/core/lib/src/local/asynchronous/request.rs b/core/lib/src/local/asynchronous/request.rs index 4c85c02024..de0d4cd957 100644 --- a/core/lib/src/local/asynchronous/request.rs +++ b/core/lib/src/local/asynchronous/request.rs @@ -85,8 +85,8 @@ impl<'c> LocalRequest<'c> { // _shouldn't_ error. Check that now and error only if not. if self.inner().uri() == invalid { error!("invalid request URI: {:?}", invalid.path()); - return LocalResponse::new(self.request, move |req| { - rocket.dispatch_error(Status::BadRequest, req) + return LocalResponse::new(self.request, move |req, error_box| { + rocket.dispatch_error(error_box.write(Box::new(Status::BadRequest)), req) }).await } } @@ -94,8 +94,8 @@ impl<'c> LocalRequest<'c> { // Actually dispatch the request. let mut data = Data::local(self.data); let token = rocket.preprocess(&mut self.request, &mut data).await; - let response = LocalResponse::new(self.request, move |req| { - rocket.dispatch(token, req, data) + let response = LocalResponse::new(self.request, move |req, error_box| { + rocket.dispatch(token, req, error_box, data) }).await; // If the client is tracking cookies, updates the internal cookie jar diff --git a/core/lib/src/local/asynchronous/response.rs b/core/lib/src/local/asynchronous/response.rs index 06ae18e3b7..93313332e5 100644 --- a/core/lib/src/local/asynchronous/response.rs +++ b/core/lib/src/local/asynchronous/response.rs @@ -1,9 +1,11 @@ use std::io; use std::future::Future; +use std::mem::transmute; use std::{pin::Pin, task::{Context, Poll}}; use tokio::io::{AsyncRead, ReadBuf}; +use crate::erased::ErrorBox; use crate::http::CookieJar; use crate::{Request, Response}; @@ -55,6 +57,8 @@ use crate::{Request, Response}; pub struct LocalResponse<'c> { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'c>, + // XXX: SAFETY: This (dependent) field must come second due to drop order! + error: ErrorBox, cookies: CookieJar<'c>, _request: Box>, } @@ -65,7 +69,7 @@ impl Drop for LocalResponse<'_> { impl<'c> LocalResponse<'c> { pub(crate) fn new(req: Request<'c>, f: F) -> impl Future> - where F: FnOnce(&'c Request<'c>) -> O + Send, + where F: FnOnce(&'c Request<'c>, &'c mut ErrorBox) -> O + Send, O: Future> + Send { // `LocalResponse` is a self-referential structure. In particular, @@ -91,19 +95,20 @@ impl<'c> LocalResponse<'c> { // away as `'_`, ensuring it is not used for any output value. let boxed_req = Box::new(req); let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) }; + let mut error_box = ErrorBox::new(); async move { // NOTE: The cookie jar `secure` state will not reflect the last // known value in `request.cookies()`. This is okay: new cookies // should never be added to the resulting jar which is the only time // the value is used to set cookie defaults. - let response: Response<'c> = f(request).await; + let response: Response<'c> = f(request, unsafe { transmute(&mut error_box) }).await; let mut cookies = CookieJar::new(None, request.rocket()); for cookie in response.cookies() { cookies.add_original(cookie.into_owned()); } - LocalResponse { _request: boxed_req, cookies, response, } + LocalResponse { _request: boxed_req, error: error_box, cookies, response, } } } } diff --git a/core/lib/src/outcome.rs b/core/lib/src/outcome.rs index 35521aa36a..6a07080272 100644 --- a/core/lib/src/outcome.rs +++ b/core/lib/src/outcome.rs @@ -86,6 +86,7 @@ //! a type of `Option`. If an `Outcome` is a `Forward`, the `Option` will be //! `None`. +use crate::catcher::TypedError; use crate::{route, request, response}; use crate::data::{self, Data, FromData}; use crate::http::Status; @@ -788,9 +789,9 @@ impl IntoOutcome> for Result { } } -impl<'r, 'o: 'r> IntoOutcome> for response::Result<'o> { +impl<'r, 'o: 'r> IntoOutcome> for response::Result<'r, 'o> { type Error = (); - type Forward = (Data<'r>, Status); + type Forward = (Data<'r>, Box + 'r>); #[inline] fn or_error(self, _: ()) -> route::Outcome<'r> { @@ -801,7 +802,7 @@ impl<'r, 'o: 'r> IntoOutcome> for response::Result<'o> { } #[inline] - fn or_forward(self, (data, forward): (Data<'r>, Status)) -> route::Outcome<'r> { + fn or_forward(self, (data, forward): (Data<'r>, Box + 'r>)) -> route::Outcome<'r> { match self { Ok(val) => Success(val), Err(_) => Forward((data, forward)) diff --git a/core/lib/src/response/content.rs b/core/lib/src/response/content.rs index 68d7a33d0f..9c819a0723 100644 --- a/core/lib/src/response/content.rs +++ b/core/lib/src/response/content.rs @@ -58,7 +58,7 @@ macro_rules! ctrs { /// Sets the Content-Type of the response then delegates the /// remainder of the response to the wrapped responder. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for $name { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { (ContentType::$ct, self.0).respond_to(req) } } @@ -78,7 +78,7 @@ ctrs! { } impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for (ContentType, R) { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { Response::build() .merge(self.1.respond_to(req)?) .header(self.0) diff --git a/core/lib/src/response/debug.rs b/core/lib/src/response/debug.rs index a7d3e612a0..e3295bb5a6 100644 --- a/core/lib/src/response/debug.rs +++ b/core/lib/src/response/debug.rs @@ -75,17 +75,17 @@ impl From for Debug { } impl<'r, E: std::fmt::Debug> Responder<'r, 'static> for Debug { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { let type_name = std::any::type_name::(); info!(type_name, value = ?self.0, "debug response (500)"); - Err(Status::InternalServerError) + Err(Box::new(Status::InternalServerError)) } } /// Prints a warning with the error and forwards to the `500` error catcher. impl<'r> Responder<'r, 'static> for std::io::Error { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { warn!("i/o error response: {self}"); - Err(Status::InternalServerError) + Err(Box::new(Status::InternalServerError)) } } diff --git a/core/lib/src/response/flash.rs b/core/lib/src/response/flash.rs index 279ec854b6..1e495c97eb 100644 --- a/core/lib/src/response/flash.rs +++ b/core/lib/src/response/flash.rs @@ -189,7 +189,7 @@ impl Flash { /// response handling to the wrapped responder. As a result, the `Outcome` of /// the response is the `Outcome` of the wrapped `Responder`. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Flash { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { req.cookies().add(self.cookie()); self.inner.respond_to(req) } diff --git a/core/lib/src/response/mod.rs b/core/lib/src/response/mod.rs index 71f0ff6980..dfd4257c07 100644 --- a/core/lib/src/response/mod.rs +++ b/core/lib/src/response/mod.rs @@ -28,6 +28,8 @@ pub mod stream; #[doc(hidden)] pub use rocket_codegen::Responder; +use crate::catcher::TypedError; + pub use self::response::{Response, Builder}; pub use self::body::Body; pub use self::responder::Responder; @@ -36,4 +38,4 @@ pub use self::flash::Flash; pub use self::debug::Debug; /// Type alias for the `Result` of a [`Responder::respond_to()`] call. -pub type Result<'r> = std::result::Result, crate::http::Status>; +pub type Result<'r, 'o> = std::result::Result, Box>>; diff --git a/core/lib/src/response/redirect.rs b/core/lib/src/response/redirect.rs index 7b5cf5d2ec..7f7a4d37bb 100644 --- a/core/lib/src/response/redirect.rs +++ b/core/lib/src/response/redirect.rs @@ -157,7 +157,7 @@ impl Redirect { /// value used to create the `Responder` is an invalid URI, an error of /// `Status::InternalServerError` is returned. impl<'r> Responder<'r, 'static> for Redirect { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { if let Some(uri) = self.1 { Response::build() .status(self.0) @@ -165,7 +165,7 @@ impl<'r> Responder<'r, 'static> for Redirect { .ok() } else { error!("Invalid URI used for redirect."); - Err(Status::InternalServerError) + Err(Box::new(Status::InternalServerError)) } } } diff --git a/core/lib/src/response/responder.rs b/core/lib/src/response/responder.rs index f31262c7fc..54d878c814 100644 --- a/core/lib/src/response/responder.rs +++ b/core/lib/src/response/responder.rs @@ -302,13 +302,13 @@ pub trait Responder<'r, 'o: 'r> { /// returned, the error catcher for the given status is retrieved and called /// to generate a final error response, which is then written out to the /// client. - fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o>; + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'r, 'o>; } /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r, 'o: 'r> Responder<'r, 'o> for &'o str { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'o> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(self)) @@ -319,7 +319,7 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for &'o str { /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for String { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(self)) @@ -339,7 +339,7 @@ impl AsRef<[u8]> for DerefRef where T::Target: AsRef<[u8] /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Arc { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(DerefRef(self))) @@ -350,7 +350,7 @@ impl<'r> Responder<'r, 'static> for Arc { /// Returns a response with Content-Type `text/plain` and a fixed-size body /// containing the string `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Box { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build() .header(ContentType::Plain) .sized_body(self.len(), Cursor::new(DerefRef(self))) @@ -361,7 +361,7 @@ impl<'r> Responder<'r, 'static> for Box { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r, 'o: 'r> Responder<'r, 'o> for &'o [u8] { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'o> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -372,7 +372,7 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for &'o [u8] { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Vec { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -383,7 +383,7 @@ impl<'r> Responder<'r, 'static> for Vec { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Arc<[u8]> { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -394,7 +394,7 @@ impl<'r> Responder<'r, 'static> for Arc<[u8]> { /// Returns a response with Content-Type `application/octet-stream` and a /// fixed-size body containing the data in `self`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for Box<[u8]> { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build() .header(ContentType::Binary) .sized_body(self.len(), Cursor::new(self)) @@ -438,7 +438,7 @@ impl<'r> Responder<'r, 'static> for Box<[u8]> { /// } /// ``` impl<'r, 'o: 'r, T: Responder<'r, 'o> + Sized> Responder<'r, 'o> for Box { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { let inner = *self; inner.respond_to(req) } @@ -446,21 +446,21 @@ impl<'r, 'o: 'r, T: Responder<'r, 'o> + Sized> Responder<'r, 'o> for Box { /// Returns a response with a sized body for the file. Always returns `Ok`. impl<'r> Responder<'r, 'static> for File { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { tokio::fs::File::from(self).respond_to(req) } } /// Returns a response with a sized body for the file. Always returns `Ok`. impl<'r> Responder<'r, 'static> for tokio::fs::File { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build().sized_body(None, self).ok() } } /// Returns an empty, default `Response`. Always returns `Ok`. impl<'r> Responder<'r, 'static> for () { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Ok(Response::new()) } } @@ -469,7 +469,7 @@ impl<'r> Responder<'r, 'static> for () { impl<'r, 'o: 'r, R: ?Sized + ToOwned> Responder<'r, 'o> for std::borrow::Cow<'o, R> where &'o R: Responder<'r, 'o> + 'o, ::Owned: Responder<'r, 'o> + 'r { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { match self { std::borrow::Cow::Borrowed(b) => b.respond_to(req), std::borrow::Cow::Owned(o) => o.respond_to(req), @@ -480,13 +480,13 @@ impl<'r, 'o: 'r, R: ?Sized + ToOwned> Responder<'r, 'o> for std::borrow::Cow<'o, /// If `self` is `Some`, responds with the wrapped `Responder`. Otherwise prints /// a warning message and returns an `Err` of `Status::NotFound`. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Option { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { match self { Some(r) => r.respond_to(req), None => { let type_name = std::any::type_name::(); debug!(type_name, "`Option` responder returned `None`"); - Err(Status::NotFound) + Err(Box::new(Status::NotFound)) }, } } @@ -497,7 +497,7 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Option { impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for Result where T: Responder<'r, 't>, E: Responder<'r, 'e> { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { match self { Ok(responder) => responder.respond_to(req), Err(responder) => responder.respond_to(req), @@ -510,7 +510,7 @@ impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for Result impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for either::Either where T: Responder<'r, 't>, E: Responder<'r, 'e> { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { match self { either::Either::Left(r) => r.respond_to(req), either::Either::Right(r) => r.respond_to(req), @@ -533,9 +533,9 @@ impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for either::Either Responder<'r, 'static> for Status { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { match self.class() { - StatusClass::ClientError | StatusClass::ServerError => Err(self), + StatusClass::ClientError | StatusClass::ServerError => Err(Box::new(self)), StatusClass::Success if self.code < 206 => { Response::build().status(self).ok() } @@ -547,7 +547,8 @@ impl<'r> Responder<'r, 'static> for Status { "invalid status used as responder\n\ status must be one of 100, 200..=205, 400..=599"); - Err(Status::InternalServerError) + // TODO: Typed: Invalid status + Err(Box::new(Status::InternalServerError)) } } } diff --git a/core/lib/src/response/status.rs b/core/lib/src/response/status.rs index 935fe88fdf..0095c48144 100644 --- a/core/lib/src/response/status.rs +++ b/core/lib/src/response/status.rs @@ -29,6 +29,8 @@ use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; use std::borrow::Cow; +use transient::Transient; + use crate::request::Request; use crate::response::{self, Responder, Response}; use crate::http::Status; @@ -163,7 +165,7 @@ impl Created { /// a hashable `Responder` is provided via [`Created::tagged_body()`]. The `ETag` /// header is set to a hash value of the responder. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Created { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { let mut response = Response::build(); if let Some(responder) = self.1 { response.merge(responder.respond_to(req)?); @@ -201,7 +203,7 @@ pub struct NoContent; /// Sets the status code of the response to 204 No Content. impl<'r> Responder<'r, 'static> for NoContent { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build().status(Status::NoContent).ok() } } @@ -231,11 +233,16 @@ impl<'r> Responder<'r, 'static> for NoContent { #[derive(Debug, Clone, PartialEq)] pub struct Custom(pub Status, pub R); +unsafe impl Transient for Custom { + type Static = Custom; + type Transience = R::Transience; +} + /// Sets the status code of the response and then delegates the remainder of the /// response to the wrapped responder. impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Custom { #[inline] - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { Response::build_from(self.1.respond_to(req)?) .status(self.0) .ok() @@ -244,7 +251,7 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Custom { impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for (Status, R) { #[inline(always)] - fn respond_to(self, request: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, request: &'r Request<'_>) -> response::Result<'r, 'o> { Custom(self.0, self.1).respond_to(request) } } @@ -289,7 +296,7 @@ macro_rules! status_response { impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for $T { #[inline(always)] - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { Custom(Status::$T, self.0).respond_to(req) } } diff --git a/core/lib/src/response/stream/bytes.rs b/core/lib/src/response/stream/bytes.rs index 52782aa241..3d164289fe 100644 --- a/core/lib/src/response/stream/bytes.rs +++ b/core/lib/src/response/stream/bytes.rs @@ -64,7 +64,7 @@ impl From for ByteStream { impl<'r, S: Stream> Responder<'r, 'r> for ByteStream where S: Send + 'r, S::Item: AsRef<[u8]> + Send + Unpin + 'r { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'r> { Response::build() .header(ContentType::Binary) .streamed_body(ReaderStream::from(self.0.map(std::io::Cursor::new))) diff --git a/core/lib/src/response/stream/reader.rs b/core/lib/src/response/stream/reader.rs index d3a3da71bf..d414996d65 100644 --- a/core/lib/src/response/stream/reader.rs +++ b/core/lib/src/response/stream/reader.rs @@ -142,7 +142,7 @@ impl From for ReaderStream { impl<'r, S: Stream> Responder<'r, 'r> for ReaderStream where S: Send + 'r, S::Item: AsyncRead + Send, { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'r> { Response::build() .streamed_body(self) .ok() diff --git a/core/lib/src/response/stream/sse.rs b/core/lib/src/response/stream/sse.rs index de24ad2816..145b57ab94 100644 --- a/core/lib/src/response/stream/sse.rs +++ b/core/lib/src/response/stream/sse.rs @@ -569,7 +569,7 @@ impl> From for EventStream { } impl<'r, S: Stream + Send + 'r> Responder<'r, 'r> for EventStream { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'r> { Response::build() .header(ContentType::EventStream) .raw_header("Cache-Control", "no-cache") diff --git a/core/lib/src/response/stream/text.rs b/core/lib/src/response/stream/text.rs index 3064e0f0e2..6b37f3e0cd 100644 --- a/core/lib/src/response/stream/text.rs +++ b/core/lib/src/response/stream/text.rs @@ -65,7 +65,7 @@ impl From for TextStream { impl<'r, S: Stream> Responder<'r, 'r> for TextStream where S: Send + 'r, S::Item: AsRef + Send + Unpin + 'r { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'r> { struct ByteStr(T); impl> AsRef<[u8]> for ByteStr { diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index b42d81e0fc..0fb9c82e9c 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -1,10 +1,10 @@ +use crate::catcher::TypedError; use crate::{Request, Data}; use crate::response::{Response, Responder}; -use crate::http::Status; /// Type alias for the return type of a [`Route`](crate::Route)'s /// [`Handler::handle()`]. -pub type Outcome<'r> = crate::outcome::Outcome, Status, (Data<'r>, Status)>; +pub type Outcome<'r> = crate::outcome::Outcome, Box>, (Data<'r>, Box>)>; /// Type alias for the return type of a _raw_ [`Route`](crate::Route)'s /// [`Handler`]. @@ -233,8 +233,8 @@ impl<'r, 'o: 'r> Outcome<'o> { /// } /// ``` #[inline(always)] - pub fn error(code: Status) -> Outcome<'r> { - Outcome::Error(code) + pub fn error>(error: E) -> Outcome<'r> { + Outcome::Error(Box::new(error)) } /// Return an `Outcome` of `Forward` with the data `data` and status @@ -253,8 +253,8 @@ impl<'r, 'o: 'r> Outcome<'o> { /// } /// ``` #[inline(always)] - pub fn forward(data: Data<'r>, status: Status) -> Outcome<'r> { - Outcome::Forward((data, status)) + pub fn forward>(data: Data<'r>, error: E) -> Outcome<'r> { + Outcome::Forward((data, Box::new(error))) } } diff --git a/core/lib/src/router/matcher.rs b/core/lib/src/router/matcher.rs index 5cb5b91831..a76d7e64b8 100644 --- a/core/lib/src/router/matcher.rs +++ b/core/lib/src/router/matcher.rs @@ -1,3 +1,5 @@ +use transient::TypeId; + use crate::{Route, Request, Catcher}; use crate::router::Collide; use crate::http::Status; @@ -133,8 +135,9 @@ impl Catcher { /// let b_count = b.base().segments().filter(|s| !s.is_empty()).count(); /// assert!(b_count > a_count); /// ``` - pub fn matches(&self, status: Status, request: &Request<'_>) -> bool { + pub fn matches(&self, status: Status, ty: Option, request: &Request<'_>) -> bool { self.code.map_or(true, |code| code == status.code) + && self.type_id == ty && self.base().segments().prefix_of(request.uri().path().segments()) } } diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 017486a6e6..2770783a4d 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -1,6 +1,7 @@ use std::ops::{Deref, DerefMut}; use std::collections::HashMap; +use crate::catcher::TypedError; use crate::request::Request; use crate::http::{Method, Status}; use crate::{Route, Catcher}; @@ -106,22 +107,21 @@ impl Router { // For many catchers, using aho-corasick or similar should be much faster. #[track_caller] - pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> { + pub fn catch<'r>(&self, status: Status, error: Option<&'r dyn TypedError<'r>>, req: &'r Request<'r>) -> Option<&Catcher> { + let ty = error.map(|e| e.trait_obj_typeid()); // Note that catchers are presorted by descending base length. - let explicit = self.catcher_map.get(&Some(status.code)) + self.catcher_map.get(&Some(status.code)) .map(|catchers| catchers.iter().map(|&i| &self.catchers[i])) - .and_then(|mut catchers| catchers.find(|c| c.matches(status, req))); + .and_then(|mut catchers| catchers.find(|c| c.matches(status, ty, req))) + } - let default = self.catcher_map.get(&None) + #[track_caller] + pub fn catch_any<'r>(&self, status: Status, error: Option<&'r dyn TypedError<'r>>, req: &'r Request<'r>) -> Option<&Catcher> { + let ty = error.map(|e| e.trait_obj_typeid()); + // Note that catchers are presorted by descending base length. + self.catcher_map.get(&None) .map(|catchers| catchers.iter().map(|&i| &self.catchers[i])) - .and_then(|mut catchers| catchers.find(|c| c.matches(status, req))); - - match (explicit, default) { - (None, None) => None, - (None, c@Some(_)) | (c@Some(_), None) => c, - (Some(a), Some(b)) if a.rank <= b.rank => Some(a), - (Some(_), Some(b)) => Some(b), - } + .and_then(|mut catchers| catchers.find(|c| c.matches(status, ty, req))) } } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index badfb44c95..6f8b5e2f0f 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -43,12 +43,12 @@ impl Rocket { let mut response = request.into_response( stream, |rocket, request, data| Box::pin(rocket.preprocess(request, data)), - |token, rocket, request, data| Box::pin(async move { + |token, rocket, request, error_box, data| Box::pin(async move { if !request.errors.is_empty() { - return rocket.dispatch_error(Status::BadRequest, request).await; + return rocket.dispatch_error(error_box.write(Box::new(Status::BadRequest)), request).await; } - rocket.dispatch(token, request, data).await + rocket.dispatch(token, request, error_box, data).await }) ).await; diff --git a/core/lib/src/trace/traceable.rs b/core/lib/src/trace/traceable.rs index 9ef1d4282f..0255501160 100644 --- a/core/lib/src/trace/traceable.rs +++ b/core/lib/src/trace/traceable.rs @@ -247,8 +247,8 @@ impl Trace for route::Outcome<'_> { }, status = match self { Self::Success(r) => r.status().code, - Self::Error(s) => s.code, - Self::Forward((_, s)) => s.code, + Self::Error(s) => s.status().code, + Self::Forward((_, s)) => s.status().code, }, ) } From ba1a7ff13545a43b8cd4aeb972668421f43be836 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 19:06:55 -0600 Subject: [PATCH 02/20] Add FromError and update Fairings --- core/lib/src/catcher/from_error.rs | 88 ++++++++++++++++++++++++++++ core/lib/src/catcher/mod.rs | 2 + core/lib/src/fairing/ad_hoc.rs | 38 ++++++++++++ core/lib/src/fairing/fairings.rs | 5 +- core/lib/src/fairing/info_kind.rs | 1 + core/lib/src/fairing/mod.rs | 35 +++++++++-- core/lib/src/lifecycle.rs | 4 +- core/lib/src/outcome.rs | 6 +- core/lib/src/request/from_request.rs | 9 +-- core/lib/src/response/flash.rs | 22 +++++-- core/lib/src/state.rs | 18 +++++- 11 files changed, 205 insertions(+), 23 deletions(-) create mode 100644 core/lib/src/catcher/from_error.rs diff --git a/core/lib/src/catcher/from_error.rs b/core/lib/src/catcher/from_error.rs new file mode 100644 index 0000000000..d0daaca320 --- /dev/null +++ b/core/lib/src/catcher/from_error.rs @@ -0,0 +1,88 @@ +use async_trait::async_trait; + +use crate::http::Status; +use crate::outcome::Outcome; +use crate::request::FromRequest; +use crate::Request; + +use crate::catcher::TypedError; + +/// Trait used to extract types for an error catcher. You should +/// pretty much never implement this yourself. There are several +/// existing implementations, that should cover every need. +/// +/// - [`Status`]: Extracts the HTTP status that this error is catching. +/// - [`&Request<'_>`]: Extracts a reference to the entire request that +/// triggered this error to begin with. +/// - [`T: FromRequest<'_>`]: Extracts type that implements `FromRequest` +/// - [`&dyn TypedError<'_>`]: Extracts the typed error, as a dynamic +/// trait object. +/// - [`Option<&dyn TypedError<'_>>`]: Same as previous, but succeeds even +/// if there is no typed error to extract. +/// +/// [`Status`]: crate::http::Status +/// [`&Request<'_>`]: crate::request::Request +/// [`&dyn TypedError<'_>`]: crate::catcher::TypedError +/// [`Option<&dyn TypedError<'_>>`]: crate::catcher::TypedError +#[async_trait] +pub trait FromError<'r>: Sized { + async fn from_error( + status: Status, + request: &'r Request<'r>, + error: &'r dyn TypedError<'r> + ) -> Result; +} + +#[async_trait] +impl<'r> FromError<'r> for Status { + async fn from_error( + status: Status, + _r: &'r Request<'r>, + _e: &'r dyn TypedError<'r> + ) -> Result { + Ok(status) + } +} + +#[async_trait] +impl<'r> FromError<'r> for &'r Request<'r> { + async fn from_error( + _s: Status, + req: &'r Request<'r>, + _e: &'r dyn TypedError<'r> + ) -> Result { + Ok(req) + } +} + +#[async_trait] +impl<'r, T: FromRequest<'r>> FromError<'r> for T { + async fn from_error( + _s: Status, + req: &'r Request<'r>, + _e: &'r dyn TypedError<'r> + ) -> Result { + match T::from_request(req).await { + Outcome::Success(val) => Ok(val), + Outcome::Error(e) => { + info!("Catcher guard error type: `{:?}`", e.name()); + Err(e.status()) + }, + Outcome::Forward(s) => { + info!(status = %s, "Catcher guard forwarding"); + Err(s) + }, + } + } +} + +#[async_trait] +impl<'r> FromError<'r> for &'r dyn TypedError<'r> { + async fn from_error( + _s: Status, + _r: &'r Request<'r>, + error: &'r dyn TypedError<'r> + ) -> Result { + Ok(error) + } +} diff --git a/core/lib/src/catcher/mod.rs b/core/lib/src/catcher/mod.rs index d9bbb48d48..f3127049b8 100644 --- a/core/lib/src/catcher/mod.rs +++ b/core/lib/src/catcher/mod.rs @@ -3,7 +3,9 @@ mod catcher; mod handler; mod types; +mod from_error; pub use catcher::*; pub use handler::*; pub use types::*; +pub use from_error::*; diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index b6dfe16b78..4a768d2871 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -1,6 +1,7 @@ use parking_lot::Mutex; use futures::future::{Future, BoxFuture, FutureExt}; +use crate::catcher::TypedError; use crate::{Rocket, Request, Response, Data, Build, Orbit}; use crate::fairing::{Fairing, Kind, Info, Result}; use crate::route::RouteUri; @@ -63,6 +64,10 @@ enum AdHocKind { Request(Box Fn(&'a mut Request<'_>, &'a mut Data<'_>) -> BoxFuture<'a, ()> + Send + Sync + 'static>), + /// An ad-hoc **request** fairing. Called when a request is received. + RequestFilter(Box Fn(&'a Request<'_>) + -> BoxFuture<'a, Result<(), Box + 'a>>> + Send + Sync + 'static>), + /// An ad-hoc **response** fairing. Called when a response is ready to be /// sent to a client. Response(Box Fn(&'r Request<'_>, &'b mut Response<'r>) @@ -159,6 +164,30 @@ impl AdHoc { AdHoc { name, kind: AdHocKind::Request(Box::new(f)) } } + /// Constructs an `AdHoc` request filter fairing named `name`. The function `f` + /// will be called and the returned `Future` will be `await`ed by Rocket + /// when a new request is received. + /// + /// # Example + /// + /// ```rust + /// use rocket::fairing::AdHoc; + /// + /// // The no-op request fairing. + /// let fairing = AdHoc::on_request_filter("Dummy", |req| { + /// Box::pin(async move { + /// // do something with the request and data... + /// # let (_, _) = (req, data); + /// Ok(()) + /// }) + /// }); + /// ``` + pub fn on_request_filter(name: &'static str, f: F) -> AdHoc + where F: for<'a> Fn(&'a Request<'_>) -> BoxFuture<'a, Result<(), Box + 'a>>> + { + AdHoc { name, kind: AdHocKind::RequestFilter(Box::new(f)) } + } + // FIXME(rustc): We'd like to allow passing `async fn` to these methods... // https://github.com/rust-lang/rust/issues/64552#issuecomment-666084589 @@ -407,6 +436,7 @@ impl Fairing for AdHoc { AdHocKind::Ignite(_) => Kind::Ignite, AdHocKind::Liftoff(_) => Kind::Liftoff, AdHocKind::Request(_) => Kind::Request, + AdHocKind::RequestFilter(_) => Kind::RequestFilter, AdHocKind::Response(_) => Kind::Response, AdHocKind::Shutdown(_) => Kind::Shutdown, }; @@ -433,6 +463,14 @@ impl Fairing for AdHoc { } } + async fn on_request_filter<'r>(&self, req: &'r Request<'_>) -> Result<(), Box + 'r>> { + if let AdHocKind::RequestFilter(ref f) = self.kind { + f(req).await + } else { + Ok(()) + } + } + async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { if let AdHocKind::Response(ref f) = self.kind { f(req, res).await diff --git a/core/lib/src/fairing/fairings.rs b/core/lib/src/fairing/fairings.rs index 0b617955fb..4974ddcae4 100644 --- a/core/lib/src/fairing/fairings.rs +++ b/core/lib/src/fairing/fairings.rs @@ -166,11 +166,11 @@ impl Fairings { } #[inline(always)] - pub async fn handle_request_filter<'r>(&self, req: &'r Request<'_>, data: &mut Data<'_>) + pub async fn handle_request_filter<'r>(&self, req: &'r Request<'_>) -> Result<(), Box + 'r>> { for fairing in iter!(self.request_filter) { - fairing.on_request_filter(req, data).await?; + fairing.on_request_filter(req).await?; } Ok(()) } @@ -227,6 +227,7 @@ impl std::fmt::Debug for Fairings { .field("launch", &debug_info(iter!(self.ignite))) .field("liftoff", &debug_info(iter!(self.liftoff))) .field("request", &debug_info(iter!(self.request))) + .field("request_filter", &debug_info(iter!(self.request_filter))) .field("response", &debug_info(iter!(self.response))) .field("shutdown", &debug_info(iter!(self.shutdown))) .finish() diff --git a/core/lib/src/fairing/info_kind.rs b/core/lib/src/fairing/info_kind.rs index 3fd28e6972..ad06db4828 100644 --- a/core/lib/src/fairing/info_kind.rs +++ b/core/lib/src/fairing/info_kind.rs @@ -39,6 +39,7 @@ pub struct Info { /// * Ignite /// * Liftoff /// * Request +/// * RequestFilter /// * Response /// * Shutdown /// diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index b52c57f383..e8573ceee6 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -154,6 +154,21 @@ pub type Result, E = Rocket> = std::result::ResultRequestFilter (`on_request_filter`)** +/// +/// A request callback, represented by the [`Fairing::on_request_filter()`] method, +/// is called just after a request is received, immediately after +/// pre-processing the request and running all `Request` fairings. This method +/// returns a `Result`, which can be used to terminate processing of a request, +// TODO: Typed: links +/// bypassing the routing process. The error value must be a `TypedError`, which +/// can then be caught by a typed catcher. +/// +/// This method should only be used for global filters, i.e., filters that need +/// to be run on every (or very nearly every) route. One common example might be +/// CORS, since the CORS headers of every request need to be inspected, and potentially +/// rejected. +/// /// * **Response (`on_response`)** /// /// A response callback, represented by the [`Fairing::on_response()`] @@ -276,6 +291,13 @@ pub type Result, E = Rocket> = std::result::Result(&self, req: &'r Request<'_>) +/// -> Result<(), Box + 'r>> +/// { +/// /* ... */ +/// # unimplemented!() +/// } +/// /// async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { /// /* ... */ /// # unimplemented!() @@ -516,11 +538,9 @@ pub trait Fairing: Send + Sync + AsAny + 'static { /// ## Default Implementation /// /// The default implementation of this method does nothing. - async fn on_request_filter<'r>(&self, _req: &'r Request<'_>, _data: &mut Data<'_>) + async fn on_request_filter<'r>(&self, _req: &'r Request<'_>) -> Result<(), Box + 'r>> - { - Ok (()) - } + { Ok (()) } /// The response callback. /// @@ -579,6 +599,13 @@ impl Fairing for std::sync::Arc { (self as &T).on_request(req, data).await } + #[inline] + async fn on_request_filter<'r>(&self, req: &'r Request<'_>) + -> Result<(), Box + 'r>> + { + (self as &T).on_request_filter(req).await + } + #[inline] async fn on_response<'r>(&self, req: &'r Request<'_>, res: &mut Response<'r>) { (self as &T).on_response(req, res).await diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index 62ec4b1c97..29721fc372 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -100,14 +100,14 @@ impl Rocket { _token: RequestToken, request: &'r Request<'s>, error_box: &mut ErrorBox, - mut data: Data<'r>, + data: Data<'r>, // io_stream: impl Future> + Send, ) -> Response<'r> { // Remember if the request is `HEAD` for later body stripping. let was_head_request = request.method() == Method::Head; // Run request filter - let mut response = if let Err(error) = self.fairings.handle_request_filter(request, &mut data).await { + let mut response = if let Err(error) = self.fairings.handle_request_filter(request).await { let error = error_box.write(error); self.dispatch_error(error, request).await } else { diff --git a/core/lib/src/outcome.rs b/core/lib/src/outcome.rs index 6a07080272..6db291e151 100644 --- a/core/lib/src/outcome.rs +++ b/core/lib/src/outcome.rs @@ -769,14 +769,14 @@ impl<'r, T: FromData<'r>> IntoOutcome> for Result IntoOutcome> for Result { - type Error = Status; + type Error = (); type Forward = Status; #[inline] - fn or_error(self, error: Status) -> request::Outcome { + fn or_error(self, _: ()) -> request::Outcome { match self { Ok(val) => Success(val), - Err(err) => Error((error, err)) + Err(err) => Error(err) } } diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index f4677db3c9..432cedee68 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -1,7 +1,7 @@ -use std::fmt::Debug; use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; +use crate::catcher::TypedError; use crate::{Request, Route}; use crate::outcome::{self, IntoOutcome, Outcome::*}; @@ -10,7 +10,7 @@ use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar}; use crate::listener::Endpoint; /// Type alias for the `Outcome` of a `FromRequest` conversion. -pub type Outcome = outcome::Outcome; +pub type Outcome = outcome::Outcome; /// Trait implemented by request guards to derive a value from incoming /// requests. @@ -376,10 +376,11 @@ pub type Outcome = outcome::Outcome; /// User` and `Admin<'a>`) as the data is now owned by the request's cache. /// /// [request-local state]: https://rocket.rs/master/guide/state/#request-local-state +// TODO: Typed: docs #[crate::async_trait] pub trait FromRequest<'r>: Sized { /// The associated error to be returned if derivation fails. - type Error: Debug; + type Error: TypedError<'r> + 'r; /// Derives an instance of `Self` from the incoming request metadata. /// @@ -513,7 +514,7 @@ impl<'r, T: FromRequest<'r>> FromRequest<'r> for Result { async fn from_request(request: &'r Request<'_>) -> Outcome { match T::from_request(request).await { Success(val) => Success(Ok(val)), - Error((_, e)) => Success(Err(e)), + Error(e) => Success(Err(e)), Forward(status) => Forward(status), } } diff --git a/core/lib/src/response/flash.rs b/core/lib/src/response/flash.rs index 1e495c97eb..73d70f7ae6 100644 --- a/core/lib/src/response/flash.rs +++ b/core/lib/src/response/flash.rs @@ -1,6 +1,8 @@ use time::Duration; use serde::ser::{Serialize, Serializer, SerializeStruct}; +use transient::Static; +use crate::catcher::TypedError; use crate::outcome::IntoOutcome; use crate::response::{self, Responder}; use crate::request::{self, Request, FromRequest}; @@ -234,6 +236,16 @@ impl<'r> FlashMessage<'r> { } } +/// Error for a FlashMessage not being present in a request. +#[derive(Debug, PartialEq, Eq)] +pub struct FlashCookieMissing; + +impl Static for FlashCookieMissing {} + +impl<'r> TypedError<'r> for FlashCookieMissing { + fn status(&self) -> Status { Status::InternalServerError } +} + /// Retrieves a flash message from a flash cookie. If there is no flash cookie, /// or if the flash cookie is malformed, an empty `Err` is returned. /// @@ -241,22 +253,22 @@ impl<'r> FlashMessage<'r> { /// in `request`: `Option`. #[crate::async_trait] impl<'r> FromRequest<'r> for FlashMessage<'r> { - type Error = (); + type Error = FlashCookieMissing; async fn from_request(req: &'r Request<'_>) -> request::Outcome { - req.cookies().get(FLASH_COOKIE_NAME).ok_or(()).and_then(|cookie| { + req.cookies().get(FLASH_COOKIE_NAME).ok_or(FlashCookieMissing).and_then(|cookie| { // Parse the flash message. let content = cookie.value(); let (len_str, kv) = match content.find(FLASH_COOKIE_DELIM) { Some(i) => (&content[..i], &content[(i + 1)..]), - None => return Err(()), + None => return Err(FlashCookieMissing), }; match len_str.parse::() { Ok(i) if i <= kv.len() => Ok(Flash::named(&kv[..i], &kv[i..], req)), - _ => Err(()) + _ => Err(FlashCookieMissing) } - }).or_error(Status::BadRequest) + }).or_error(()) } } diff --git a/core/lib/src/state.rs b/core/lib/src/state.rs index aa5d941d97..c256c7f0ec 100644 --- a/core/lib/src/state.rs +++ b/core/lib/src/state.rs @@ -3,7 +3,9 @@ use std::ops::Deref; use std::any::type_name; use ref_cast::RefCast; +use transient::Static; +use crate::catcher::TypedError; use crate::{Phase, Rocket, Ignite, Sentinel}; use crate::request::{self, FromRequest, Request}; use crate::outcome::Outcome; @@ -191,12 +193,22 @@ impl<'r, T: Send + Sync + 'static> From<&'r T> for &'r State { } } +/// Error for a managed state element not being present. +#[derive(Debug, PartialEq, Eq)] +pub struct StateMissing(pub &'static str); + +impl Static for StateMissing {} + +impl<'r> TypedError<'r> for StateMissing { + fn status(&self) -> Status { Status::InternalServerError } +} + #[crate::async_trait] impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State { - type Error = (); + type Error = StateMissing; #[inline(always)] - async fn from_request(req: &'r Request<'_>) -> request::Outcome { + async fn from_request(req: &'r Request<'_>) -> request::Outcome { match State::get(req.rocket()) { Some(state) => Outcome::Success(state), None => { @@ -204,7 +216,7 @@ impl<'r, T: Send + Sync + 'static> FromRequest<'r> for &'r State { "retrieving unmanaged state\n\ state must be managed via `rocket.manage()`"); - Outcome::Error((Status::InternalServerError, ())) + Outcome::Error(StateMissing(type_name::())) } } } From 3eb41541c0cb35eba505d76034f7b2d7bb8594f6 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 19:14:09 -0600 Subject: [PATCH 03/20] Update sync_db_pools for new API --- contrib/sync_db_pools/lib/src/connection.rs | 22 ++++++++++++++++----- core/lib/src/erased.rs | 4 ++-- core/lib/src/local/asynchronous/response.rs | 4 ++-- core/lib/src/response/flash.rs | 2 +- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/contrib/sync_db_pools/lib/src/connection.rs b/contrib/sync_db_pools/lib/src/connection.rs index 34f43e176c..8359e1c673 100644 --- a/contrib/sync_db_pools/lib/src/connection.rs +++ b/contrib/sync_db_pools/lib/src/connection.rs @@ -1,6 +1,8 @@ +use std::any::type_name; use std::sync::Arc; use std::marker::PhantomData; +use rocket::catcher::{Static, TypedError}; use rocket::{Phase, Rocket, Ignite, Sentinel}; use rocket::fairing::{AdHoc, Fairing}; use rocket::request::{Request, Outcome, FromRequest}; @@ -221,19 +223,29 @@ impl Drop for ConnectionPool { } } +/// Error for a managed state element not being present. +#[derive(Debug, PartialEq, Eq)] +pub struct ConnectionMissing(pub &'static str); + +impl Static for ConnectionMissing {} + +impl<'r> TypedError<'r> for ConnectionMissing { + fn status(&self) -> Status { Status::InternalServerError } +} + #[rocket::async_trait] impl<'r, K: 'static, C: Poolable> FromRequest<'r> for Connection { - type Error = (); + type Error = ConnectionMissing; #[inline] - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { match request.rocket().state::>() { - Some(c) => c.get().await.or_error((Status::ServiceUnavailable, ())), + Some(c) => c.get().await.or_error(ConnectionMissing(type_name::())), None => { - let conn = std::any::type_name::(); + let conn = type_name::(); error!("`{conn}::fairing()` is not attached\n\ the fairing must be attached to use `{conn} in routes."); - Outcome::Error((Status::InternalServerError, ())) + Outcome::Error(ConnectionMissing(type_name::())) } } } diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs index 2cd86bbf8a..02647b724d 100644 --- a/core/lib/src/erased.rs +++ b/core/lib/src/erased.rs @@ -65,7 +65,7 @@ pub struct ErasedResponse { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'static>, // XXX: SAFETY: This (dependent) field must come second due to drop order! - error: ErrorBox, + _error: ErrorBox, _request: Arc, } @@ -151,7 +151,7 @@ impl ErasedRequest { ErasedResponse { _request: parent, - error, + _error: error, response, } } diff --git a/core/lib/src/local/asynchronous/response.rs b/core/lib/src/local/asynchronous/response.rs index 93313332e5..4166d9cd29 100644 --- a/core/lib/src/local/asynchronous/response.rs +++ b/core/lib/src/local/asynchronous/response.rs @@ -58,7 +58,7 @@ pub struct LocalResponse<'c> { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'c>, // XXX: SAFETY: This (dependent) field must come second due to drop order! - error: ErrorBox, + _error: ErrorBox, cookies: CookieJar<'c>, _request: Box>, } @@ -108,7 +108,7 @@ impl<'c> LocalResponse<'c> { cookies.add_original(cookie.into_owned()); } - LocalResponse { _request: boxed_req, error: error_box, cookies, response, } + LocalResponse { _request: boxed_req, _error: error_box, cookies, response, } } } } diff --git a/core/lib/src/response/flash.rs b/core/lib/src/response/flash.rs index 73d70f7ae6..ac5b0febda 100644 --- a/core/lib/src/response/flash.rs +++ b/core/lib/src/response/flash.rs @@ -243,7 +243,7 @@ pub struct FlashCookieMissing; impl Static for FlashCookieMissing {} impl<'r> TypedError<'r> for FlashCookieMissing { - fn status(&self) -> Status { Status::InternalServerError } + fn status(&self) -> Status { Status::BadRequest } } /// Retrieves a flash message from a flash cookie. If there is no flash cookie, From 368bd062be084a55722c59519abfda95b3cb0831 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 19:22:12 -0600 Subject: [PATCH 04/20] Update db_pools --- contrib/db_pools/lib/src/database.rs | 29 ++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/contrib/db_pools/lib/src/database.rs b/contrib/db_pools/lib/src/database.rs index 4467eaed06..5265982b3e 100644 --- a/contrib/db_pools/lib/src/database.rs +++ b/contrib/db_pools/lib/src/database.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; +use rocket::catcher::{Static, TypedError}; use rocket::{error, Build, Ignite, Phase, Rocket, Sentinel, Orbit}; use rocket::fairing::{self, Fairing, Info, Kind}; use rocket::request::{FromRequest, Outcome, Request}; @@ -278,17 +279,37 @@ impl Fairing for Initializer { } } +/// Possible errors when aquiring a database connection +#[derive(Debug, PartialEq, Eq)] +pub enum DbError { + ServiceUnavailable(E), + InternalError, +} + +impl Static for DbError {} + +impl<'r, E: Send + Sync + 'static> TypedError<'r> for DbError { + fn status(&self) -> Status { + match self { + Self::ServiceUnavailable(_) => Status::ServiceUnavailable, + Self::InternalError => Status::InternalServerError, + } + } +} + #[rocket::async_trait] -impl<'r, D: Database> FromRequest<'r> for Connection { - type Error = Option<::Error>; +impl<'r, D: Database> FromRequest<'r> for Connection + where ::Error: Send + Sync + 'static +{ + type Error = DbError<::Error>; async fn from_request(req: &'r Request<'_>) -> Outcome { match D::fetch(req.rocket()) { Some(db) => match db.get().await { Ok(conn) => Outcome::Success(Connection(conn)), - Err(e) => Outcome::Error((Status::ServiceUnavailable, Some(e))), + Err(e) => Outcome::Error(DbError::ServiceUnavailable(e)), }, - None => Outcome::Error((Status::InternalServerError, None)), + None => Outcome::Error(DbError::InternalError), } } } From 0c5f576dd776a2c6ace28532d1bc50792dfe175e Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 19:23:39 -0600 Subject: [PATCH 05/20] Update websockets --- contrib/ws/src/websocket.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contrib/ws/src/websocket.rs b/contrib/ws/src/websocket.rs index 361550441e..fba837fcba 100644 --- a/contrib/ws/src/websocket.rs +++ b/contrib/ws/src/websocket.rs @@ -238,7 +238,7 @@ impl<'r> FromRequest<'r> for WebSocket { } impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'o> { Response::build() .raw_header("Sec-Websocket-Version", "13") .raw_header("Sec-WebSocket-Accept", self.ws.key.clone()) @@ -250,7 +250,7 @@ impl<'r, 'o: 'r> Responder<'r, 'o> for Channel<'o> { impl<'r, 'o: 'r, S> Responder<'r, 'o> for MessageStream<'o, S> where S: futures::Stream> + Send + 'o { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'o> { Response::build() .raw_header("Sec-Websocket-Version", "13") .raw_header("Sec-WebSocket-Accept", self.ws.key.clone()) From fded2a8980ced2536375bcae95649e4966058dc6 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 19:33:57 -0600 Subject: [PATCH 06/20] Update rocket_dyn_templates --- contrib/dyn_templates/src/metadata.rs | 18 ++++++++++-------- contrib/dyn_templates/src/template.rs | 2 +- core/lib/src/catcher/types.rs | 6 ++++++ core/lib/src/lib.rs | 1 + 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/contrib/dyn_templates/src/metadata.rs b/contrib/dyn_templates/src/metadata.rs index 8c19b1f5f1..a20985841a 100644 --- a/contrib/dyn_templates/src/metadata.rs +++ b/contrib/dyn_templates/src/metadata.rs @@ -1,7 +1,8 @@ use std::fmt; use std::borrow::Cow; -use rocket::{Request, Rocket, Ignite, Sentinel}; +use rocket::outcome::Outcome; +use rocket::{Ignite, Request, Rocket, Sentinel, StateMissing}; use rocket::http::{Status, ContentType}; use rocket::request::{self, FromRequest}; use rocket::serde::Serialize; @@ -152,18 +153,19 @@ impl Sentinel for Metadata<'_> { /// (`500`) is returned. #[rocket::async_trait] impl<'r> FromRequest<'r> for Metadata<'r> { - type Error = (); + type Error = StateMissing; - async fn from_request(request: &'r Request<'_>) -> request::Outcome { - request.rocket().state::() - .map(|cm| request::Outcome::Success(Metadata(cm))) - .unwrap_or_else(|| { + async fn from_request(request: &'r Request<'_>) -> request::Outcome { + match request.rocket().state::() { + Some(cm) => Outcome::Success(cm), + None => { error!( "uninitialized template context: missing `Template::fairing()`.\n\ To use templates, you must attach `Template::fairing()`." ); - request::Outcome::Error((Status::InternalServerError, ())) - }) + request::Outcome::Error(StateMissing("Template::fairing()")) + } + } } } diff --git a/contrib/dyn_templates/src/template.rs b/contrib/dyn_templates/src/template.rs index 97a73b7b76..79c2b2bcc2 100644 --- a/contrib/dyn_templates/src/template.rs +++ b/contrib/dyn_templates/src/template.rs @@ -265,7 +265,7 @@ impl Template { /// extension and a fixed-size body containing the rendered template. If /// rendering fails, an `Err` of `Status::InternalServerError` is returned. impl<'r> Responder<'r, 'static> for Template { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { let ctxt = req.rocket() .state::() .ok_or_else(|| { diff --git a/core/lib/src/catcher/types.rs b/core/lib/src/catcher/types.rs index 2eae56a97e..a2eaa563df 100644 --- a/core/lib/src/catcher/types.rs +++ b/core/lib/src/catcher/types.rs @@ -77,6 +77,12 @@ impl<'r> TypedError<'r> for Status { } } +impl<'r> From for Box + 'r> { + fn from(value: Status) -> Self { + Box::new(value) + } +} + // TODO: Typed: update transient to make the possible. // impl<'r, R: TypedError<'r> + Transient> TypedError<'r> for (Status, R) // where R::Transience: CanTranscendTo> diff --git a/core/lib/src/lib.rs b/core/lib/src/lib.rs index 8e629ed685..b2a099e8a2 100644 --- a/core/lib/src/lib.rs +++ b/core/lib/src/lib.rs @@ -179,6 +179,7 @@ mod erased; #[doc(inline)] pub use crate::rkt::Rocket; #[doc(inline)] pub use crate::shutdown::Shutdown; #[doc(inline)] pub use crate::state::State; +#[doc(inline)] pub use crate::state::StateMissing; /// Retrofits support for `async fn` in trait impls and declarations. /// From 35cc4fa611b31f6ecc518907ca809b66a014035c Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 20:14:11 -0600 Subject: [PATCH 07/20] Add todo notes --- core/lib/src/catcher/handler.rs | 1 + core/lib/src/lifecycle.rs | 1 + core/lib/src/route/handler.rs | 1 + 3 files changed, 3 insertions(+) diff --git a/core/lib/src/catcher/handler.rs b/core/lib/src/catcher/handler.rs index ea64fff500..e05302d18a 100644 --- a/core/lib/src/catcher/handler.rs +++ b/core/lib/src/catcher/handler.rs @@ -88,6 +88,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>; /// directly as the parameter to `rocket.register("/", )`. /// 3. Unlike static-function-based handlers, this custom handler can make use /// of internal state. +// TODO: Typed: Docs #[crate::async_trait] pub trait Handler: Cloneable + Send + Sync + 'static { /// Called by Rocket when an error with `status` for a given `Request` diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index 29721fc372..f3be7d8d19 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -297,6 +297,7 @@ impl Rocket { /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))` /// if the handler ran to completion but failed. Returns `Ok(None)` if the /// handler panicked while executing. + // TODO: Typed: Docs async fn invoke_catcher<'s, 'r: 's>( &'s self, error: &'r dyn TypedError<'r>, diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index 0fb9c82e9c..e4d02b1693 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -133,6 +133,7 @@ pub type BoxFuture<'r, T = Outcome<'r>> = futures::future::BoxFuture<'r, T>; /// Use this alternative when a single configuration is desired and your custom /// handler is private to your application. For all other cases, a custom /// `Handler` implementation is preferred. +// TODO: Typed: Docs #[crate::async_trait] pub trait Handler: Cloneable + Send + Sync + 'static { /// Called by Rocket when a `Request` with its associated `Data` should be From e0fbae62081992a500410a5ecbd452a6c3eb83e4 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Sun, 17 Nov 2024 20:30:09 -0600 Subject: [PATCH 08/20] Fixing more compile issues --- contrib/dyn_templates/src/metadata.rs | 2 +- core/http/src/uri/error.rs | 4 ++++ core/lib/src/catcher/types.rs | 4 ++++ core/lib/src/error.rs | 10 ++++++++++ core/lib/src/mtls/certificate.rs | 4 ++-- core/lib/src/mtls/error.rs | 11 ++++++++++- core/lib/src/serde/json.rs | 4 ++-- core/lib/src/serde/msgpack.rs | 2 +- 8 files changed, 34 insertions(+), 7 deletions(-) diff --git a/contrib/dyn_templates/src/metadata.rs b/contrib/dyn_templates/src/metadata.rs index a20985841a..db7bf677ba 100644 --- a/contrib/dyn_templates/src/metadata.rs +++ b/contrib/dyn_templates/src/metadata.rs @@ -157,7 +157,7 @@ impl<'r> FromRequest<'r> for Metadata<'r> { async fn from_request(request: &'r Request<'_>) -> request::Outcome { match request.rocket().state::() { - Some(cm) => Outcome::Success(cm), + Some(cm) => Outcome::Success(Metadata(cm)), None => { error!( "uninitialized template context: missing `Template::fairing()`.\n\ diff --git a/core/http/src/uri/error.rs b/core/http/src/uri/error.rs index 06705d27f2..bb56b53a1c 100644 --- a/core/http/src/uri/error.rs +++ b/core/http/src/uri/error.rs @@ -2,6 +2,8 @@ use std::fmt; +use transient::Static; + pub use crate::parse::uri::Error; /// The error type returned when a URI conversion fails. @@ -29,6 +31,8 @@ pub enum PathError { BadEnd(char), } +impl Static for PathError {} + impl fmt::Display for PathError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/core/lib/src/catcher/types.rs b/core/lib/src/catcher/types.rs index a2eaa563df..b968e5f3b6 100644 --- a/core/lib/src/catcher/types.rs +++ b/core/lib/src/catcher/types.rs @@ -150,6 +150,10 @@ impl<'r> TypedError<'r> for std::string::FromUtf8Error { fn status(&self) -> Status { Status::BadRequest } } +impl<'r> TypedError<'r> for crate::http::uri::error::PathError { + fn status(&self) -> Status { Status::BadRequest } +} + #[cfg(feature = "json")] impl<'r> TypedError<'r> for serde_json::Error { fn status(&self) -> Status { Status::BadRequest } diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 60bbdd87ec..12a0ec4708 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -5,7 +5,10 @@ use std::error::Error as StdError; use std::sync::Arc; use figment::Profile; +use transient::Static; +use crate::http::Status; +use crate::catcher::TypedError; use crate::listener::Endpoint; use crate::{Catcher, Ignite, Orbit, Phase, Rocket, Route}; use crate::trace::Trace; @@ -89,6 +92,13 @@ pub enum ErrorKind { #[derive(Clone, Copy, Default, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct Empty; +impl Static for Empty {} +impl<'r> TypedError<'r> for Empty { + fn status(&self) -> Status { + Status::BadRequest + } +} + /// An error that occurs when a value doesn't match one of the expected options. /// /// This error is returned by the [`FromParam`] trait implementation generated diff --git a/core/lib/src/mtls/certificate.rs b/core/lib/src/mtls/certificate.rs index 3bccb7bb8d..fe887d434d 100644 --- a/core/lib/src/mtls/certificate.rs +++ b/core/lib/src/mtls/certificate.rs @@ -114,13 +114,13 @@ impl<'r> FromRequest<'r> for Certificate<'r> { async fn from_request(req: &'r Request<'_>) -> Outcome { use crate::outcome::{try_outcome, IntoOutcome}; - let certs = req.connection + let certs: Outcome<_, Error> = req.connection .peer_certs .as_ref() .or_forward(Status::Unauthorized); let chain = try_outcome!(certs); - Certificate::parse(chain.inner()).or_error(Status::Unauthorized) + Certificate::parse(chain.inner()).or_error(()) } } diff --git a/core/lib/src/mtls/error.rs b/core/lib/src/mtls/error.rs index 703835f299..9741121e38 100644 --- a/core/lib/src/mtls/error.rs +++ b/core/lib/src/mtls/error.rs @@ -1,7 +1,10 @@ use std::fmt; use std::num::NonZeroUsize; -use crate::mtls::x509::{self, nom}; +use transient::Static; + +use crate::{catcher::TypedError, mtls::x509::{self, nom}}; +use crate::http::Status; /// An error returned by the [`Certificate`](crate::mtls::Certificate) guard. /// @@ -41,6 +44,12 @@ pub enum Error { Trailing(usize), } +impl Static for Error {} + +impl<'r> TypedError<'r> for Error { + fn status(&self) -> Status { Status::Unauthorized } +} + impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/core/lib/src/serde/json.rs b/core/lib/src/serde/json.rs index d68b24f04b..fd08f2401e 100644 --- a/core/lib/src/serde/json.rs +++ b/core/lib/src/serde/json.rs @@ -216,7 +216,7 @@ impl<'r, T: Deserialize<'r>> FromData<'r> for Json { /// JSON and a fixed-size body with the serialized value. If serialization /// fails, an `Err` of `Status::InternalServerError` is returned. impl<'r, T: Serialize> Responder<'r, 'static> for Json { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { let string = serde_json::to_string(&self.0) .map_err(|e| { error!("JSON serialize failure: {}", e); @@ -298,7 +298,7 @@ impl<'v, T: Deserialize<'v> + Send> form::FromFormField<'v> for Json { /// Serializes the value into JSON. Returns a response with Content-Type JSON /// and a fixed-size body with the serialized value. impl<'r> Responder<'r, 'static> for Value { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { content::RawJson(self.to_string()).respond_to(req) } } diff --git a/core/lib/src/serde/msgpack.rs b/core/lib/src/serde/msgpack.rs index bf758b8fe1..93fcd4511d 100644 --- a/core/lib/src/serde/msgpack.rs +++ b/core/lib/src/serde/msgpack.rs @@ -216,7 +216,7 @@ impl<'r, T: Deserialize<'r>> FromData<'r> for MsgPack { /// Content-Type `MsgPack` and a fixed-size body with the serialization. If /// serialization fails, an `Err` of `Status::InternalServerError` is returned. impl<'r, T: Serialize, const COMPACT: bool> Responder<'r, 'static> for MsgPack { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { let maybe_buf = if COMPACT { rmp_serde::to_vec(&self.0) } else { From ed5618a062b00ab2874db420dfac738ecbff6572 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Tue, 19 Nov 2024 19:24:50 -0600 Subject: [PATCH 09/20] Complete initial implementation --- contrib/db_pools/codegen/src/database.rs | 6 +- contrib/dyn_templates/src/metadata.rs | 2 +- contrib/sync_db_pools/codegen/src/database.rs | 4 +- contrib/ws/src/websocket.rs | 2 +- core/codegen/src/attribute/catch/mod.rs | 73 +++++-- core/codegen/src/attribute/catch/parse.rs | 47 ++++- core/codegen/src/attribute/route/mod.rs | 50 +++-- core/codegen/src/derive/mod.rs | 1 + core/codegen/src/derive/responder.rs | 13 +- core/codegen/src/derive/typed_error.rs | 160 +++++++++++++++ core/codegen/src/exports.rs | 4 + core/codegen/src/lib.rs | 7 + core/codegen/tests/catcher.rs | 14 +- core/codegen/tests/route-data.rs | 3 +- core/lib/Cargo.toml | 5 +- core/lib/src/catcher/from_error.rs | 6 +- core/lib/src/catcher/types.rs | 188 ++++++++++-------- core/lib/src/data/capped.rs | 4 +- core/lib/src/data/from_data.rs | 36 ++-- core/lib/src/form/error.rs | 32 ++- core/lib/src/form/form.rs | 2 +- core/lib/src/form/from_form_field.rs | 12 +- core/lib/src/form/parser.rs | 8 +- core/lib/src/fs/temp_file.rs | 13 +- core/lib/src/mtls/certificate.rs | 3 +- core/lib/src/outcome.rs | 28 +-- core/lib/src/request/from_param.rs | 134 ++++++++++++- core/lib/src/request/from_request.rs | 56 +++--- core/lib/src/request/mod.rs | 2 +- core/lib/src/response/debug.rs | 16 ++ core/lib/src/response/mod.rs | 2 +- core/lib/src/response/responder.rs | 8 +- core/lib/src/response/status.rs | 29 ++- core/lib/src/route/handler.rs | 4 +- core/lib/src/router/router.rs | 2 +- core/lib/src/serde/json.rs | 30 ++- core/lib/src/serde/msgpack.rs | 11 +- .../lib/tests/forward-includes-status-1560.rs | 4 +- .../local-request-content-type-issue-505.rs | 6 +- .../lib/tests/responder_lifetime-issue-345.rs | 2 +- examples/cookies/src/session.rs | 9 +- examples/error-handling/Cargo.toml | 3 +- examples/error-handling/src/main.rs | 62 +++++- examples/error-handling/src/tests.rs | 21 +- examples/manual-routing/src/main.rs | 15 +- examples/pastebin/src/paste_id.rs | 8 +- examples/responders/src/main.rs | 6 +- examples/serialization/src/tests.rs | 4 +- examples/serialization/src/uuid.rs | 4 +- examples/state/src/request_local.rs | 18 +- examples/todo/src/main.rs | 13 +- 51 files changed, 859 insertions(+), 333 deletions(-) create mode 100644 core/codegen/src/derive/typed_error.rs diff --git a/contrib/db_pools/codegen/src/database.rs b/contrib/db_pools/codegen/src/database.rs index a2ca218de1..b1cbc513a5 100644 --- a/contrib/db_pools/codegen/src/database.rs +++ b/contrib/db_pools/codegen/src/database.rs @@ -60,15 +60,15 @@ pub fn derive_database(input: TokenStream) -> TokenStream { #[rocket::async_trait] impl<'r> rocket::request::FromRequest<'r> for &'r #decorated_type { - type Error = (); + type Error = rocket::http::Status; async fn from_request( req: &'r rocket::request::Request<'_> ) -> rocket::request::Outcome { match #db_ty::fetch(req.rocket()) { Some(db) => rocket::outcome::Outcome::Success(db), - None => rocket::outcome::Outcome::Error(( - rocket::http::Status::InternalServerError, ())) + None => rocket::outcome::Outcome::Error( + rocket::http::Status::InternalServerError) } } } diff --git a/contrib/dyn_templates/src/metadata.rs b/contrib/dyn_templates/src/metadata.rs index db7bf677ba..82ada354da 100644 --- a/contrib/dyn_templates/src/metadata.rs +++ b/contrib/dyn_templates/src/metadata.rs @@ -3,7 +3,7 @@ use std::borrow::Cow; use rocket::outcome::Outcome; use rocket::{Ignite, Request, Rocket, Sentinel, StateMissing}; -use rocket::http::{Status, ContentType}; +use rocket::http::ContentType; use rocket::request::{self, FromRequest}; use rocket::serde::Serialize; diff --git a/contrib/sync_db_pools/codegen/src/database.rs b/contrib/sync_db_pools/codegen/src/database.rs index 51de85c5ef..be033f9f9c 100644 --- a/contrib/sync_db_pools/codegen/src/database.rs +++ b/contrib/sync_db_pools/codegen/src/database.rs @@ -112,11 +112,11 @@ pub fn database_attr(attr: TokenStream, input: TokenStream) -> Result #rocket::request::FromRequest<'r> for #guard_type { - type Error = (); + type Error = ::rocket_sync_db_pools::ConnectionMissing; async fn from_request( __r: &'r #rocket::request::Request<'_> - ) -> #rocket::request::Outcome { + ) -> #rocket::request::Outcome { <#conn>::from_request(__r).await.map(Self) } } diff --git a/contrib/ws/src/websocket.rs b/contrib/ws/src/websocket.rs index fba837fcba..5eea92d53b 100644 --- a/contrib/ws/src/websocket.rs +++ b/contrib/ws/src/websocket.rs @@ -213,7 +213,7 @@ pub struct MessageStream<'r, S> { #[rocket::async_trait] impl<'r> FromRequest<'r> for WebSocket { - type Error = std::convert::Infallible; + type Error = Status; async fn from_request(req: &'r Request<'_>) -> Outcome { use crate::tungstenite::handshake::derive_accept_key; diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index 57c898a059..f6824f8f7f 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -1,11 +1,12 @@ mod parse; -use devise::ext::SpanDiagnosticExt; +use devise::ext::TypeExt; use devise::{Spanned, Result}; use proc_macro2::{TokenStream, Span}; +use syn::{Lifetime, TypeReference}; use crate::http_codegen::Optional; -use crate::syn_ext::ReturnTypeExt; +use crate::syn_ext::{FnArgExt, IdentExt, ReturnTypeExt}; use crate::exports::*; pub fn _catch( @@ -22,35 +23,63 @@ pub fn _catch( let status_code = Optional(catch.status.map(|s| s.code)); let deprecated = catch.function.attrs.iter().find(|a| a.path().is_ident("deprecated")); - // Determine the number of parameters that will be passed in. - if catch.function.sig.inputs.len() > 2 { - return Err(catch.function.sig.paren_token.span.join() - .error("invalid number of arguments: must be zero, one, or two") - .help("catchers optionally take `&Request` or `Status, &Request`")); - } - // This ensures that "Responder not implemented" points to the return type. let return_type_span = catch.function.sig.output.ty() .map(|ty| ty.span()) .unwrap_or_else(Span::call_site); - // Set the `req` and `status` spans to that of their respective function - // arguments for a more correct `wrong type` error span. `rev` to be cute. - let codegen_args = &[__req, __status]; - let inputs = catch.function.sig.inputs.iter().rev() - .zip(codegen_args.iter()) - .map(|(fn_arg, codegen_arg)| match fn_arg { - syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()), - syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span()) - }).rev(); + let from_error = catch.guards.iter().map(|g| { + let name = g.fn_ident.rocketized(); + let ty = g.ty.with_replaced_lifetimes(Lifetime::new("'__r", g.ty.span())); + quote_spanned!(g.span() => + let #name: #ty = match <#ty as #FromError<'__r>>::from_error(#__status, #__req, #__error).await { + #_Ok(v) => v, + #_Err(s) => { + // TODO: Typed: log failure + return #_Err(s); + }, + }; + ) + }); + + let error = catch.error.iter().map(|g| { + let name = g.fn_ident.rocketized(); + let ty = g.ty.with_replaced_lifetimes(Lifetime::new("'__r", g.ty.span())); + quote!( + let #name: #ty = match #_catcher::downcast(#__error) { + Some(v) => v, + None => { + // TODO: Typed: log failure - this should never happen + return #_Err(#Status::InternalServerError); + }, + }; + ) + }); + + let error_type = Optional(catch.error.as_ref().map(|g| { + let ty = match &g.ty { + syn::Type::Reference(TypeReference { mutability: None, elem, .. }) => { + elem.as_ref().with_stripped_lifetimes() + }, + _ => todo!("Invalid type"), + }; + quote_spanned!(g.span() => + #_catcher::TypeId::of::<#ty>() + ) + })); // We append `.await` to the function call if this is `async`. let dot_await = catch.function.sig.asyncness .map(|a| quote_spanned!(a.span() => .await)); + let args = catch.function.sig.inputs.iter().map(|a| { + let name = a.typed().unwrap().0.rocketized(); + quote!(#name) + }); + let catcher_response = quote_spanned!(return_type_span => { - let ___responder = #user_catcher_fn_name(#(#inputs),*) #dot_await; - #_response::Responder::respond_to(___responder, #__req)? + let ___responder = #user_catcher_fn_name(#(#args),*) #dot_await; + #_response::Responder::respond_to(___responder, #__req).map_err(|e| e.status())? }); // Generate the catcher, keeping the user's input around. @@ -68,9 +97,12 @@ pub fn _catch( fn into_info(self) -> #_catcher::StaticInfo { fn monomorphized_function<'__r>( #__status: #Status, + #__error: &'__r dyn #TypedError<'__r>, #__req: &'__r #Request<'_> ) -> #_catcher::BoxFuture<'__r> { #_Box::pin(async move { + #(#from_error)* + #(#error)* let __response = #catcher_response; #Response::build() .status(#__status) @@ -83,6 +115,7 @@ pub fn _catch( name: ::core::stringify!(#user_catcher_fn_name), code: #status_code, handler: monomorphized_function, + type_id: #error_type, location: (::core::file!(), ::core::line!(), ::core::column!()), } } diff --git a/core/codegen/src/attribute/catch/parse.rs b/core/codegen/src/attribute/catch/parse.rs index 34125c9c74..8764d14aa8 100644 --- a/core/codegen/src/attribute/catch/parse.rs +++ b/core/codegen/src/attribute/catch/parse.rs @@ -1,13 +1,21 @@ use devise::ext::SpanDiagnosticExt; -use devise::{MetaItem, Spanned, Result, FromMeta, Diagnostic}; +use devise::{Diagnostic, FromMeta, MetaItem, Result, SpanWrapped, Spanned}; use proc_macro2::TokenStream; +use crate::attribute::param::{Dynamic, Guard}; +use crate::name::Name; +use crate::proc_macro_ext::Diagnostics; +use crate::syn_ext::FnArgExt; use crate::{http, http_codegen}; /// This structure represents the parsed `catch` attribute and associated items. pub struct Attribute { /// The status associated with the code in the `#[catch(code)]` attribute. pub status: Option, + /// The parameter to be used as the error type. + pub error: Option, + /// All the other guards + pub guards: Vec, /// The function that was decorated with the `catch` attribute. pub function: syn::ItemFn, } @@ -17,6 +25,7 @@ pub struct Attribute { struct Meta { #[meta(naked)] code: Code, + error: Option>, } /// `Some` if there's a code, `None` if it's `default`. @@ -48,11 +57,41 @@ impl Attribute { .map_err(|diag| diag.help("`#[catch]` can only be used on functions"))?; let attr: MetaItem = syn::parse2(quote!(catch(#args)))?; - let status = Meta::from_meta(&attr) - .map(|meta| meta.code.0) + let meta = Meta::from_meta(&attr) + // .map(|meta| meta.code.0) .map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \ `#[catch(404)]` or `#[catch(default)]`"))?; - Ok(Attribute { status, function }) + let mut diags = Diagnostics::new(); + let mut guards = Vec::new(); + let mut error = None; + for (index, arg) in function.sig.inputs.iter().enumerate() { + if let Some((ident, ty)) = arg.typed() { + match meta.error.as_ref() { + Some(err) if Name::from(ident) == err.name => { + error = Some(Guard { source: meta.error.clone().unwrap().value, fn_ident: ident.clone(), ty: ty.clone() }); + } + _ => { + guards.push(Guard { source: Dynamic { name: Name::from(ident), index, trailing: false }, fn_ident: ident.clone(), ty: ty.clone() }) + } + } + } else { + let span = arg.span(); + let diag = if arg.wild().is_some() { + span.error("handler arguments must be named") + .help("to name an ignored handler argument, use `_name`") + } else { + span.error("handler arguments must be of the form `ident: Type`") + }; + + diags.push(diag); + } + } + if meta.error.is_some() != error.is_some() { + let span = meta.error.unwrap().span(); + diags.push(span.error("Error parameter not found on function")); + } + + diags.head_err_or(Attribute { status: meta.code.0, error, guards, function }) } } diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index b2979d3fac..68d8ccde86 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -41,7 +41,7 @@ fn query_decls(route: &Route) -> Option { } define_spanned_export!(Span::call_site() => - __req, __data, _form, Outcome, _Ok, _Err, _Some, _None, Status + __req, __data, _form, Outcome, _Ok, _Err, _Some, _None, TypedError ); // Record all of the static parameters for later filtering. @@ -108,13 +108,13 @@ fn query_decls(route: &Route) -> Option { ::rocket::trace::span_info!( "codegen", "query string failed to match route declaration" => - { for _err in __e { ::rocket::trace::info!( + { for _err in __e.iter() { ::rocket::trace::info!( target: concat!("rocket::codegen::route::", module_path!()), "{_err}" ); } } ); - return #Outcome::Forward((#__data, #Status::UnprocessableEntity)); + return #Outcome::Forward((#__data, Box::new(__e) as Box + '__r>)); } (#(#ident.unwrap()),*) @@ -125,7 +125,7 @@ fn query_decls(route: &Route) -> Option { fn request_guard_decl(guard: &Guard) -> TokenStream { let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); define_spanned_export!(ty.span() => - __req, __data, _request, display_hack, FromRequest, Outcome + __req, __data, _request, FromRequest, Outcome, TypedError ); quote_spanned! { ty.span() => @@ -137,24 +137,25 @@ fn request_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - status = __e.code, + status = #TypedError::status(&__e).code, "request guard forwarding" ); - return #Outcome::Forward((#__data, __e)); + return #Outcome::Forward((#__data, Box::new(__e) as Box + '__r>)); }, #[allow(unreachable_code)] - #Outcome::Error((__c, __e)) => { + #Outcome::Error(__c) => { ::rocket::trace::info!( name: "failure", target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - reason = %#display_hack!(__e), + error_name = #TypedError::name(&__c), + // reason = %#display_hack!(__e), "request guard failed" ); - return #Outcome::Error(__c); + return #Outcome::Error(Box::new(__c) as Box + '__r>); } }; } @@ -164,7 +165,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { let (i, name, ty) = (guard.index, &guard.name, &guard.ty); define_spanned_export!(ty.span() => __req, __data, _None, _Some, _Ok, _Err, - Outcome, FromSegments, FromParam, Status, display_hack + Outcome, FromSegments, FromParam, Status, TypedError, FromParamError, FromSegmentsError ); // Returned when a dynamic parameter fails to parse. @@ -174,11 +175,12 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = #name, type_name = stringify!(#ty), - reason = %#display_hack!(__error), + name = #TypedError::name(&__error), + // reason = %#display_hack!(__error), "path guard forwarding" ); - #Outcome::Forward((#__data, #Status::UnprocessableEntity)) + #Outcome::Forward((#__data, Box::new(__error) as Box + '__r>)) }); // All dynamic parameters should be found if this function is being called; @@ -189,7 +191,10 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { #_Some(__s) => match <#ty as #FromParam>::from_param(__s) { #_Ok(__v) => __v, #[allow(unreachable_code)] - #_Err(__error) => return #parse_error, + #_Err(__error) => { + let __error = #FromParamError::new(__s, __error); + return #parse_error; + } }, #_None => { ::rocket::trace::error!( @@ -200,7 +205,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { #i ); - return #Outcome::Forward((#__data, #Status::InternalServerError)); + return #Outcome::Forward((#__data, Box::new(#Status::InternalServerError) as Box + '__r>)) } } }, @@ -208,7 +213,10 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { match <#ty as #FromSegments>::from_segments(#__req.routed_segments(#i..)) { #_Ok(__v) => __v, #[allow(unreachable_code)] - #_Err(__error) => return #parse_error, + #_Err(__error) => { + let __error = #FromSegmentsError::new(#__req.routed_segments(#i..), __error); + return #parse_error; + } } }, }; @@ -219,7 +227,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { fn data_guard_decl(guard: &Guard) -> TokenStream { let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); - define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome); + define_spanned_export!(ty.span() => __req, __data, FromData, Outcome, TypedError); quote_spanned! { ty.span() => let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await { @@ -230,24 +238,24 @@ fn data_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - status = __e.code, + status = #TypedError::status(&__e).code, "data guard forwarding" ); - return #Outcome::Forward((__d, __e)); + return #Outcome::Forward((__d, Box::new(__e) as Box + '__r>)); } #[allow(unreachable_code)] - #Outcome::Error((__c, __e)) => { + #Outcome::Error(__e) => { ::rocket::trace::info!( name: "failure", target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - reason = %#display_hack!(__e), + // reason = %#display_hack!(__e), "data guard failed" ); - return #Outcome::Error(__c); + return #Outcome::Error(Box::new(__e) as Box + '__r>); } }; } diff --git a/core/codegen/src/derive/mod.rs b/core/codegen/src/derive/mod.rs index ef77c74947..2eb80e5be6 100644 --- a/core/codegen/src/derive/mod.rs +++ b/core/codegen/src/derive/mod.rs @@ -4,3 +4,4 @@ pub mod from_form_field; pub mod responder; pub mod uri_display; pub mod from_param; +pub mod typed_error; diff --git a/core/codegen/src/derive/responder.rs b/core/codegen/src/derive/responder.rs index 736ee97d42..1ff402d711 100644 --- a/core/codegen/src/derive/responder.rs +++ b/core/codegen/src/derive/responder.rs @@ -1,3 +1,8 @@ +// use quote::ToTokens; +// use crate::{exports::{*, Status as _Status}, syn_ext::IdentExt}; +// use devise::{*, ext::{TypeExt, SpanDiagnosticExt}}; +// use crate::http_codegen::{ContentType, Status}; + use quote::ToTokens; use devise::{*, ext::{TypeExt, SpanDiagnosticExt}}; use proc_macro2::TokenStream; @@ -6,12 +11,12 @@ use crate::exports::*; use crate::syn_ext::{TypeExt as _, GenericsExt as _}; use crate::http_codegen::{ContentType, Status}; + #[derive(Debug, Default, FromMeta)] struct ItemAttr { content_type: Option>, - status: Option>, + status: Option>, } - #[derive(Default, FromMeta)] struct FieldAttr { ignore: bool, @@ -65,7 +70,9 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream { ) .inner_mapper(MapperBuild::new() .with_output(|_, output| quote! { - fn respond_to(self, __req: &'r #Request<'_>) -> #_response::Result<'o> { + fn respond_to(self, __req: &'r #Request<'_>) + -> #_response::Result<'r, 'o> + { #output } }) diff --git a/core/codegen/src/derive/typed_error.rs b/core/codegen/src/derive/typed_error.rs new file mode 100644 index 0000000000..45960993f3 --- /dev/null +++ b/core/codegen/src/derive/typed_error.rs @@ -0,0 +1,160 @@ +use devise::{*, ext::SpanDiagnosticExt}; +use proc_macro2::TokenStream; +use syn::{ConstParam, Index, LifetimeParam, Member, TypeParam}; + +use crate::exports::{*, Status as _Status}; +use crate::http_codegen::Status; + +#[derive(Debug, Default, FromMeta)] +struct ItemAttr { + status: Option>, + /// Option to generate a respond_to impl with the debug repr of the type + debug: bool, +} + +#[derive(Default, FromMeta)] +struct FieldAttr { + source: bool, +} + +pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { + let impl_tokens = quote!(impl<'r> #TypedError<'r>); + let typed_error: TokenStream = DeriveGenerator::build_for(input.clone(), impl_tokens) + .support(Support::Struct | Support::Enum | Support::Lifetime | Support::Type) + .replace_generic(0, 0) + .type_bound_mapper(MapperBuild::new() + .input_map(|_, i| { + let bounds = i.generics().type_params().map(|g| &g.ident); + quote! { #(#bounds: ::std::marker::Send + ::std::marker::Sync + 'static,)* } + }) + ) + .validator(ValidatorBuild::new() + .input_validate(|_, i| match i.generics().lifetimes().count() > 1 { + true => Err(i.generics().span().error("only one lifetime is supported")), + false => Ok(()) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + #[allow(unused_variables)] + fn respond_to(&self, request: &'r #Request<'_>) + -> #_Result<#Response<'r>, #_Status> + { + #output + } + }) + .try_fields_map(|_, fields| { + let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())?.unwrap_or(Default::default()); + let status = item.status.map_or(quote!(#_Status::InternalServerError), |m| quote!(#m)); + Ok(if item.debug { + quote! { + use #_response::Responder; + #_response::Debug(self) + .respond_to(request) + .map_err(|_| #status) + .map(|mut r| { r.set_status(#status); r }) + } + } else { + quote! { + #_Err(#status) + } + }) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + fn source(&'r self) -> #_Option<&'r (dyn #TypedError<'r> + 'r)> { + #output + } + }) + .try_fields_map(|_, fields| { + let mut source = None; + for field in fields.iter() { + if FieldAttr::one_from_attrs("error", &field.attrs)?.is_some_and(|a| a.source) { + if source.is_some() { + return Err(Diagnostic::spanned( + field.span(), + Level::Error, + "Only one field may be declared as `#[error(source)]`")); + } + if let FieldParent::Variant(_) = field.parent { + let name = field.match_ident(); + source = Some(quote! { #_Some(#name as &dyn #TypedError<'r>) }) + } else { + let span = field.field.span().into(); + let member = match field.ident { + Some(ref ident) => Member::Named(ident.clone()), + None => Member::Unnamed(Index { index: field.index as u32, span }) + }; + + source = Some(quote_spanned!( + span => #_Some(&self.#member as &dyn #TypedError<'r> + ))); + } + } + } + Ok(source.unwrap_or_else(|| quote! { #_None })) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + fn status(&self) -> #_Status { #output } + }) + .try_fields_map(|_, fields| { + let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())?.unwrap_or(Default::default()); + let status = item.status.map_or(quote!(#_Status::InternalServerError), |m| quote!(#m)); + Ok(quote! { #status }) + }) + ) + .to_tokens(); + let impl_tokens = quote!(unsafe impl #_catcher::Transient); + let transient: TokenStream = DeriveGenerator::build_for(input, impl_tokens) + .support(Support::Struct | Support::Enum | Support::Lifetime | Support::Type) + .replace_generic(1, 0) + .type_bound_mapper(MapperBuild::new() + .input_map(|_, i| { + let bounds = i.generics().type_params().map(|g| &g.ident); + quote! { #(#bounds: 'static,)* } + }) + ) + .validator(ValidatorBuild::new() + .input_validate(|_, i| match i.generics().lifetimes().count() > 1 { + true => Err(i.generics().span().error("only one lifetime is supported")), + false => Ok(()) + }) + ) + .inner_mapper(MapperBuild::new() + .with_output(|_, output| quote! { + #output + }) + .input_map(|_, input| { + let name = input.ident(); + let args = input.generics() + .params + .iter() + .map(|g| { + match g { + syn::GenericParam::Lifetime(_) => quote!{ 'static }, + syn::GenericParam::Type(TypeParam { ident, .. }) => quote! { #ident }, + syn::GenericParam::Const(ConstParam { .. }) => todo!(), + } + }); + let trans = input.generics() + .lifetimes() + .map(|LifetimeParam { lifetime, .. }| quote!{#_catcher::Inv<#lifetime>}); + quote!{ + type Static = #name <#(#args)*>; + type Transience = (#(#trans,)*); + } + }) + ) + // TODO: hack to generate unsafe impl + .outer_mapper(MapperBuild::new() + .input_map(|_, _| quote!{ unsafe }) + ) + .to_tokens(); + quote!{ + #typed_error + #transient + } +} diff --git a/core/codegen/src/exports.rs b/core/codegen/src/exports.rs index 50470b46b9..de9b574321 100644 --- a/core/codegen/src/exports.rs +++ b/core/codegen/src/exports.rs @@ -102,6 +102,10 @@ define_exported_paths! { Route => ::rocket::Route, Catcher => ::rocket::Catcher, Status => ::rocket::http::Status, + TypedError => ::rocket::catcher::TypedError, + FromError => ::rocket::catcher::FromError, + FromParamError => ::rocket::request::FromParamError, + FromSegmentsError => ::rocket::request::FromSegmentsError, } macro_rules! define_spanned_export { diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index 3b41af8b38..731f0db79e 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -1016,6 +1016,13 @@ pub fn derive_responder(input: TokenStream) -> TokenStream { emit!(derive::responder::derive_responder(input)) } +/// Derive for the [`TypedError`] trait. +// TODO: Typed: Docs +#[proc_macro_derive(TypedError, attributes(error))] +pub fn derive_typed_error(input: TokenStream) -> TokenStream { + emit!(derive::typed_error::derive_typed_error(input)) +} + /// Derive for the [`UriDisplay`] trait. /// /// The [`UriDisplay`] derive can be applied to enums and structs. When diff --git a/core/codegen/tests/catcher.rs b/core/codegen/tests/catcher.rs index ddc59cb175..59a9b1b345 100644 --- a/core/codegen/tests/catcher.rs +++ b/core/codegen/tests/catcher.rs @@ -10,9 +10,9 @@ use rocket::local::blocking::Client; use rocket::http::Status; #[catch(404)] fn not_found_0() -> &'static str { "404-0" } -#[catch(404)] fn not_found_1(_: &Request<'_>) -> &'static str { "404-1" } -#[catch(404)] fn not_found_2(_: Status, _: &Request<'_>) -> &'static str { "404-2" } -#[catch(default)] fn all(_: Status, r: &Request<'_>) -> String { r.uri().to_string() } +#[catch(404)] fn not_found_1(_r: &Request<'_>) -> &'static str { "404-1" } +#[catch(404)] fn not_found_2(_s: Status, _r: &Request<'_>) -> &'static str { "404-2" } +#[catch(default)] fn all(_s: Status, r: &Request<'_>) -> String { r.uri().to_string() } #[test] fn test_simple_catchers() { @@ -37,10 +37,10 @@ fn test_simple_catchers() { } #[get("/")] fn forward(code: u16) -> Status { Status::new(code) } -#[catch(400)] fn forward_400(status: Status, _: &Request<'_>) -> String { status.code.to_string() } -#[catch(404)] fn forward_404(status: Status, _: &Request<'_>) -> String { status.code.to_string() } -#[catch(444)] fn forward_444(status: Status, _: &Request<'_>) -> String { status.code.to_string() } -#[catch(500)] fn forward_500(status: Status, _: &Request<'_>) -> String { status.code.to_string() } +#[catch(400)] fn forward_400(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } +#[catch(404)] fn forward_404(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } +#[catch(444)] fn forward_444(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } +#[catch(500)] fn forward_500(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } #[test] fn test_status_param() { diff --git a/core/codegen/tests/route-data.rs b/core/codegen/tests/route-data.rs index 1ecd0452f3..687b05a459 100644 --- a/core/codegen/tests/route-data.rs +++ b/core/codegen/tests/route-data.rs @@ -1,5 +1,6 @@ #[macro_use] extern crate rocket; +use rocket::response::status::BadRequest; use rocket::{Request, Data}; use rocket::local::blocking::Client; use rocket::data::{self, FromData}; @@ -17,7 +18,7 @@ struct Simple<'r>(&'r str); #[async_trait] impl<'r> FromData<'r> for Simple<'r> { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> { <&'r str>::from_data(req, data).await.map(Simple) diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 848d53f6e0..156288fa17 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -128,8 +128,9 @@ features = ["provider-address-token-default", "provider-tls-rustls"] optional = true [dependencies.s2n-quic-h3] -git = "https://github.com/SergioBenitez/s2n-quic-h3.git" -rev = "6613956" +# git = "https://github.com/SergioBenitez/s2n-quic-h3.git" +# rev = "6613956" +path = "../../../s2n-quic-h3" optional = true [target.'cfg(unix)'.dependencies] diff --git a/core/lib/src/catcher/from_error.rs b/core/lib/src/catcher/from_error.rs index d0daaca320..a5202f6ef2 100644 --- a/core/lib/src/catcher/from_error.rs +++ b/core/lib/src/catcher/from_error.rs @@ -68,9 +68,9 @@ impl<'r, T: FromRequest<'r>> FromError<'r> for T { info!("Catcher guard error type: `{:?}`", e.name()); Err(e.status()) }, - Outcome::Forward(s) => { - info!(status = %s, "Catcher guard forwarding"); - Err(s) + Outcome::Forward(e) => { + info!("Catcher guard error type: `{:?}`", e.name()); + Err(e.status()) }, } } diff --git a/core/lib/src/catcher/types.rs b/core/lib/src/catcher/types.rs index b968e5f3b6..723a29cdb6 100644 --- a/core/lib/src/catcher/types.rs +++ b/core/lib/src/catcher/types.rs @@ -1,3 +1,5 @@ +use std::fmt; + use either::Either; use transient::{Any, CanRecoverFrom, Downcast, Transience}; use crate::{http::Status, response::status::Custom, Request, Response}; @@ -164,6 +166,19 @@ impl<'r> TypedError<'r> for rmp_serde::encode::Error { } #[cfg(feature = "msgpack")] impl<'r> TypedError<'r> for rmp_serde::decode::Error { + fn status(&self) -> Status { + match self { + rmp_serde::decode::Error::InvalidDataRead(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Status::BadRequest, + | rmp_serde::decode::Error::TypeMismatch(..) + | rmp_serde::decode::Error::OutOfRange + | rmp_serde::decode::Error::LengthMismatch(..) => Status::UnprocessableEntity, + _ => Status::BadRequest, + } + } +} + +#[cfg(feature = "uuid")] +impl<'r> TypedError<'r> for uuid_::Error { fn status(&self) -> Status { Status::BadRequest } } @@ -209,6 +224,12 @@ impl<'r, L, R> TypedError<'r> for Either } } +impl fmt::Debug for dyn TypedError<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "<{} as TypedError>", self.name()) + } +} + // // TODO: This cannot be used as a bound on an untyped catcher to get any error type. // // This is mostly an implementation detail (and issue with double boxing) for // // the responder derive @@ -239,106 +260,103 @@ pub fn type_id_of<'r, T: TypedError<'r> + Transient + 'r>() -> (TypeId, &'static /// Downcast an error type to the underlying concrete type. Used by the `#[catch]` attribute. #[doc(hidden)] -pub fn downcast<'r, T>(v: Option<&'r dyn TypedError<'r>>) -> Option<&'r T> +pub fn downcast<'r, T>(v: &'r dyn TypedError<'r>) -> Option<&'r T> where T: TypedError<'r> + Transient + 'r, T::Transience: CanRecoverFrom>, { - // if v.is_none() { - // crate::trace::error!("No value to downcast from"); - // } - let v = v?; // crate::trace::error!("Downcasting error from {}", v.name()); v.as_any().downcast_ref() } -/// Upcasts a value to `Box>`, falling back to a default if it doesn't implement -/// `Error` -#[doc(hidden)] -#[macro_export] -macro_rules! resolve_typed_catcher { - ($T:expr) => ({ - #[allow(unused_imports)] - use $crate::catcher::resolution::{Resolve, DefaultTypeErase, ResolvedTypedError}; - - let inner = Resolve::new($T).cast(); - ResolvedTypedError { - name: inner.as_ref().ok().map(|e| e.name()), - val: inner, - } - }); -} +// TODO: Typed: This isn't used at all right now +// /// Upcasts a value to `Box>`, falling back to a default if it doesn't implement +// /// `Error` +// #[doc(hidden)] +// #[macro_export] +// macro_rules! resolve_typed_catcher { +// ($T:expr) => ({ +// #[allow(unused_imports)] +// use $crate::catcher::resolution::{Resolve, DefaultTypeErase, ResolvedTypedError}; + +// let inner = Resolve::new($T).cast(); +// ResolvedTypedError { +// name: inner.as_ref().ok().map(|e| e.name()), +// val: inner, +// } +// }); +// } -pub use resolve_typed_catcher; +// pub use resolve_typed_catcher; -pub mod resolution { - use std::marker::PhantomData; +// pub mod resolution { +// use std::marker::PhantomData; - use transient::{CanTranscendTo, Transient}; +// use transient::{CanTranscendTo, Transient}; - use super::*; +// use super::*; - /// The *magic*. - /// - /// `Resolve::item` for `T: Transient` is `::item`. - /// `Resolve::item` for `T: !Transient` is `DefaultTypeErase::item`. - /// - /// This _must_ be used as `Resolve:::item` for resolution to work. This - /// is a fun, static dispatch hack for "specialization" that works because - /// Rust prefers inherent methods over blanket trait impl methods. - pub struct Resolve<'r, T: 'r>(pub T, PhantomData<&'r ()>); - - impl<'r, T: 'r> Resolve<'r, T> { - pub fn new(val: T) -> Self { - Self(val, PhantomData) - } - } +// /// The *magic*. +// /// +// /// `Resolve::item` for `T: Transient` is `::item`. +// /// `Resolve::item` for `T: !Transient` is `DefaultTypeErase::item`. +// /// +// /// This _must_ be used as `Resolve:::item` for resolution to work. This +// /// is a fun, static dispatch hack for "specialization" that works because +// /// Rust prefers inherent methods over blanket trait impl methods. +// pub struct Resolve<'r, T: 'r>(pub T, PhantomData<&'r ()>); - /// Fallback trait "implementing" `Transient` for all types. This is what - /// Rust will resolve `Resolve::item` to when `T: !Transient`. - pub trait DefaultTypeErase<'r>: Sized { - const SPECIALIZED: bool = false; +// impl<'r, T: 'r> Resolve<'r, T> { +// pub fn new(val: T) -> Self { +// Self(val, PhantomData) +// } +// } - fn cast(self) -> Result>, Self> { Err(self) } - } +// /// Fallback trait "implementing" `Transient` for all types. This is what +// /// Rust will resolve `Resolve::item` to when `T: !Transient`. +// pub trait DefaultTypeErase<'r>: Sized { +// const SPECIALIZED: bool = false; + +// fn cast(self) -> Result>, Self> { Err(self) } +// } - impl<'r, T: 'r> DefaultTypeErase<'r> for Resolve<'r, T> {} +// impl<'r, T: 'r> DefaultTypeErase<'r> for Resolve<'r, T> {} - /// "Specialized" "implementation" of `Transient` for `T: Transient`. This is - /// what Rust will resolve `Resolve::item` to when `T: Transient`. - impl<'r, T: TypedError<'r> + Transient> Resolve<'r, T> - where T::Transience: CanTranscendTo> - { - pub const SPECIALIZED: bool = true; +// /// "Specialized" "implementation" of `Transient` for `T: Transient`. This is +// /// what Rust will resolve `Resolve::item` to when `T: Transient`. +// impl<'r, T: TypedError<'r> + Transient> Resolve<'r, T> +// where T::Transience: CanTranscendTo> +// { +// pub const SPECIALIZED: bool = true; - pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } - } +// pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } +// } - // TODO: These extensions maybe useful, but so far not really - // // Box can be upcast without double boxing? - // impl<'r> Resolve<'r, Box>> { - // pub const SPECIALIZED: bool = true; - - // pub fn cast(self) -> Result>, Self> { Ok(self.0) } - // } - - // Ideally, we should be able to handle this case, but we can't, since we don't own `Either` - // impl<'r, A, B> Resolve<'r, Either> - // where A: TypedError<'r> + Transient, - // A::Transience: CanTranscendTo>, - // B: TypedError<'r> + Transient, - // B::Transience: CanTranscendTo>, - // { - // pub const SPECIALIZED: bool = true; - - // pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } - // } - - /// Wrapper type to hold the return type of `resolve_typed_catcher`. - #[doc(hidden)] - pub struct ResolvedTypedError<'r, T> { - /// The return value from `TypedError::name()`, if Some - pub name: Option<&'static str>, - /// The upcast error, if it supports it - pub val: Result + 'r>, Resolve<'r, T>>, - } -} +// // TODO: These extensions maybe useful, but so far not really +// // // Box can be upcast without double boxing? +// // impl<'r> Resolve<'r, Box>> { +// // pub const SPECIALIZED: bool = true; + +// // pub fn cast(self) -> Result>, Self> { Ok(self.0) } +// // } + +// // Ideally, we should be able to handle this case, but we can't, since we don't own `Either` +// // impl<'r, A, B> Resolve<'r, Either> +// // where A: TypedError<'r> + Transient, +// // A::Transience: CanTranscendTo>, +// // B: TypedError<'r> + Transient, +// // B::Transience: CanTranscendTo>, +// // { +// // pub const SPECIALIZED: bool = true; + +// // pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } +// // } + +// /// Wrapper type to hold the return type of `resolve_typed_catcher`. +// #[doc(hidden)] +// pub struct ResolvedTypedError<'r, T> { +// /// The return value from `TypedError::name()`, if Some +// pub name: Option<&'static str>, +// /// The upcast error, if it supports it +// pub val: Result + 'r>, Resolve<'r, T>>, +// } +// } diff --git a/core/lib/src/data/capped.rs b/core/lib/src/data/capped.rs index 094a10d5b4..b254a21675 100644 --- a/core/lib/src/data/capped.rs +++ b/core/lib/src/data/capped.rs @@ -260,10 +260,10 @@ macro_rules! impl_strict_from_data_from_capped { Success(p) if p.is_complete() => Success(p.into_inner()), Success(_) => { let e = Error::new(UnexpectedEof, "data limit exceeded"); - Error((Status::BadRequest, e.into())) + Error(e.into()) }, Forward(d) => Forward(d), - Error((s, e)) => Error((s, e)), + Error(e) => Error(e), } } } diff --git a/core/lib/src/data/from_data.rs b/core/lib/src/data/from_data.rs index 3eec28932d..18d52aac51 100644 --- a/core/lib/src/data/from_data.rs +++ b/core/lib/src/data/from_data.rs @@ -1,13 +1,14 @@ -use crate::http::{RawStr, Status}; +use crate::catcher::TypedError; +use crate::http::RawStr; use crate::request::{Request, local_cache}; use crate::data::{Data, Limits}; -use crate::outcome::{self, IntoOutcome, try_outcome, Outcome::*}; +use crate::outcome::{self, try_outcome, Outcome::*}; /// Type alias for the `Outcome` of [`FromData`]. /// /// [`FromData`]: crate::data::FromData pub type Outcome<'r, T, E = >::Error> - = outcome::Outcome, Status)>; + = outcome::Outcome, E)>; /// Trait implemented by data guards to derive a value from request body data. /// @@ -303,7 +304,7 @@ pub type Outcome<'r, T, E = >::Error> #[crate::async_trait] pub trait FromData<'r>: Sized { /// The associated error to be returned when the guard fails. - type Error: Send + std::fmt::Debug; + type Error: TypedError<'r> + 'r; /// Asynchronously validates, parses, and converts an instance of `Self` /// from the incoming request body data. @@ -315,14 +316,18 @@ pub trait FromData<'r>: Sized { } use crate::data::Capped; +use crate::response::status::BadRequest; #[crate::async_trait] impl<'r> FromData<'r> for Capped { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { let limit = req.limits().get("string").unwrap_or(Limits::STRING); - data.open(limit).into_string().await.or_error(Status::BadRequest) + match data.open(limit).into_string().await { + Ok(v) => Outcome::Success(v), + Err(e) => Outcome::Error(BadRequest(e)), + } } } @@ -330,7 +335,7 @@ impl_strict_from_data_from_capped!(String); #[crate::async_trait] impl<'r> FromData<'r> for Capped<&'r str> { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { let capped = try_outcome!(>::from_data(req, data).await); @@ -343,7 +348,7 @@ impl_strict_from_data_from_capped!(&'r str); #[crate::async_trait] impl<'r> FromData<'r> for Capped<&'r RawStr> { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { let capped = try_outcome!(>::from_data(req, data).await); @@ -356,7 +361,7 @@ impl_strict_from_data_from_capped!(&'r RawStr); #[crate::async_trait] impl<'r> FromData<'r> for Capped> { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { let capped = try_outcome!(>::from_data(req, data).await); @@ -368,7 +373,7 @@ impl_strict_from_data_from_capped!(std::borrow::Cow<'_, str>); #[crate::async_trait] impl<'r> FromData<'r> for Capped<&'r [u8]> { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { let capped = try_outcome!(>>::from_data(req, data).await); @@ -381,11 +386,14 @@ impl_strict_from_data_from_capped!(&'r [u8]); #[crate::async_trait] impl<'r> FromData<'r> for Capped> { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { let limit = req.limits().get("bytes").unwrap_or(Limits::BYTES); - data.open(limit).into_bytes().await.or_error(Status::BadRequest) + match data.open(limit).into_bytes().await { + Ok(v) => Outcome::Success(v), + Err(e) => Outcome::Error(BadRequest(e)) + } } } @@ -402,12 +410,12 @@ impl<'r> FromData<'r> for Data<'r> { #[crate::async_trait] impl<'r, T: FromData<'r> + 'r> FromData<'r> for Result { - type Error = std::convert::Infallible; + type Error = T::Error; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { match T::from_data(req, data).await { Success(v) => Success(Ok(v)), - Error((_, e)) => Success(Err(e)), + Error(e) => Success(Err(e)), Forward(d) => Forward(d), } } diff --git a/core/lib/src/form/error.rs b/core/lib/src/form/error.rs index b2c2c06e30..5ae9228a34 100644 --- a/core/lib/src/form/error.rs +++ b/core/lib/src/form/error.rs @@ -8,7 +8,9 @@ use std::net::AddrParseError; use std::borrow::Cow; use serde::{Serialize, ser::{Serializer, SerializeStruct}}; +use transient::Transient; +use crate::catcher::TypedError; use crate::http::Status; use crate::form::name::{NameBuf, Name}; use crate::data::ByteUnit; @@ -54,10 +56,16 @@ use crate::data::ByteUnit; /// Ok(i) /// } /// ``` -#[derive(Default, Debug, PartialEq, Serialize)] +#[derive(Default, Debug, PartialEq, Serialize, Transient)] #[serde(transparent)] pub struct Errors<'v>(Vec>); +impl<'r> TypedError<'r> for Errors<'r> { + fn status(&self) -> Status { + self.status() + } +} + /// A form error, potentially tied to a specific form field. /// /// An `Error` is returned by [`FromForm`], [`FromFormField`], and [`validate`] @@ -196,7 +204,7 @@ pub enum ErrorKind<'v> { Unknown, /// A custom error occurred. Status defaults to /// [`Status::UnprocessableEntity`] if one is not directly specified. - Custom(Status, Box), + Custom(Status, Box), /// An error while parsing a multipart form occurred. Multipart(multer::Error), /// A string was invalid UTF-8. @@ -213,6 +221,8 @@ pub enum ErrorKind<'v> { Addr(AddrParseError), /// An I/O error occurred. Io(io::Error), + /// An Unsupported media type + UnsupportedMediaType, } /// The erroneous form entity or form component. @@ -451,9 +461,9 @@ impl<'v> Error<'v> { /// } /// ``` pub fn custom(error: E) -> Self - where E: std::error::Error + Send + 'static + where E: std::error::Error + Send + Sync + 'static { - (Box::new(error) as Box).into() + (Box::new(error) as Box).into() } /// Creates a new `Error` with `ErrorKind::Validation` and message `msg`. @@ -732,6 +742,7 @@ impl<'v> Error<'v> { Unknown => Status::InternalServerError, Io(_) if self.entity == Entity::Form => Status::BadRequest, Custom(status, _) => status, + UnsupportedMediaType => Status::UnsupportedMediaType, _ => Status::UnprocessableEntity } } @@ -866,6 +877,7 @@ impl fmt::Display for ErrorKind<'_> { ErrorKind::Float(e) => write!(f, "invalid float: {}", e)?, ErrorKind::Addr(e) => write!(f, "invalid address: {}", e)?, ErrorKind::Io(e) => write!(f, "i/o error: {}", e)?, + ErrorKind::UnsupportedMediaType => write!(f, "unsupported media type")?, } Ok(()) @@ -900,7 +912,8 @@ impl crate::http::ext::IntoOwned for ErrorKind<'_> { .map(|s| Cow::Owned(s.to_string())) .collect::>() .into() - } + }, + UnsupportedMediaType => UnsupportedMediaType, } } } @@ -966,14 +979,14 @@ impl<'a, 'v: 'a, const N: usize> From<&'static [Cow<'v, str>; N]> for ErrorKind< } } -impl<'a> From> for ErrorKind<'a> { - fn from(e: Box) -> Self { +impl<'a> From> for ErrorKind<'a> { + fn from(e: Box) -> Self { ErrorKind::Custom(Status::UnprocessableEntity, e) } } -impl<'a> From<(Status, Box)> for ErrorKind<'a> { - fn from((status, e): (Status, Box)) -> Self { +impl<'a> From<(Status, Box)> for ErrorKind<'a> { + fn from((status, e): (Status, Box)) -> Self { ErrorKind::Custom(status, e) } } @@ -1042,6 +1055,7 @@ impl Entity { | ErrorKind::Unknown | ErrorKind::Unexpected => Entity::Field, + | ErrorKind::UnsupportedMediaType | ErrorKind::Multipart(_) | ErrorKind::Io(_) => Entity::Form, } diff --git a/core/lib/src/form/form.rs b/core/lib/src/form/form.rs index bcb2a2414b..1dc002f50c 100644 --- a/core/lib/src/form/form.rs +++ b/core/lib/src/form/form.rs @@ -333,7 +333,7 @@ impl<'r, T: FromForm<'r>> FromData<'r> for Form { match T::finalize(context) { Ok(value) => Outcome::Success(Form(value)), - Err(e) => Outcome::Error((e.status(), e)), + Err(e) => Outcome::Error(e), } } } diff --git a/core/lib/src/form/from_form_field.rs b/core/lib/src/form/from_form_field.rs index 2a7f5ab22b..6b1eecf988 100644 --- a/core/lib/src/form/from_form_field.rs +++ b/core/lib/src/form/from_form_field.rs @@ -297,7 +297,7 @@ impl<'v> FromFormField<'v> for Capped<&'v str> { match as FromData>::from_data(f.request, f.data).await { Outcome::Success(p) => Ok(p), - Outcome::Error((_, e)) => Err(e)?, + Outcome::Error(e) => Err(e.0)?, Outcome::Forward(..) => { Err(Error::from(ErrorKind::Unexpected).with_entity(Entity::DataField))? } @@ -318,7 +318,7 @@ impl<'v> FromFormField<'v> for Capped { match as FromData>::from_data(f.request, f.data).await { Outcome::Success(p) => Ok(p), - Outcome::Error((_, e)) => Err(e)?, + Outcome::Error(e) => Err(e.0)?, Outcome::Forward(..) => { Err(Error::from(ErrorKind::Unexpected).with_entity(Entity::DataField))? } @@ -354,7 +354,7 @@ impl<'v> FromFormField<'v> for Capped<&'v [u8]> { match as FromData>::from_data(f.request, f.data).await { Outcome::Success(p) => Ok(p), - Outcome::Error((_, e)) => Err(e)?, + Outcome::Error(e) => Err(e.0)?, Outcome::Forward(..) => { Err(Error::from(ErrorKind::Unexpected).with_entity(Entity::DataField))? } @@ -412,7 +412,7 @@ static DATE_TIME_FMT2: &[FormatItem<'_>] = impl<'v> FromFormField<'v> for Date { fn from_value(field: ValueField<'v>) -> Result<'v, Self> { let date = Self::parse(field.value, &DATE_FMT) - .map_err(|e| Box::new(e) as Box)?; + .map_err(|e| Box::new(e) as Box)?; Ok(date) } @@ -422,7 +422,7 @@ impl<'v> FromFormField<'v> for Time { fn from_value(field: ValueField<'v>) -> Result<'v, Self> { let time = Self::parse(field.value, &TIME_FMT1) .or_else(|_| Self::parse(field.value, &TIME_FMT2)) - .map_err(|e| Box::new(e) as Box)?; + .map_err(|e| Box::new(e) as Box)?; Ok(time) } @@ -432,7 +432,7 @@ impl<'v> FromFormField<'v> for PrimitiveDateTime { fn from_value(field: ValueField<'v>) -> Result<'v, Self> { let dt = Self::parse(field.value, &DATE_TIME_FMT1) .or_else(|_| Self::parse(field.value, &DATE_TIME_FMT2)) - .map_err(|e| Box::new(e) as Box)?; + .map_err(|e| Box::new(e) as Box)?; Ok(dt) } diff --git a/core/lib/src/form/parser.rs b/core/lib/src/form/parser.rs index be9130a4e1..980b02e0f1 100644 --- a/core/lib/src/form/parser.rs +++ b/core/lib/src/form/parser.rs @@ -3,7 +3,7 @@ use either::Either; use crate::request::{Request, local_cache_once}; use crate::data::{Data, Limits, Outcome}; -use crate::http::{RawStr, Status}; +use crate::http::RawStr; use crate::form::prelude::*; type Result<'r, T> = std::result::Result>; @@ -35,12 +35,14 @@ impl<'r, 'i> Parser<'r, 'i> { let parser = match req.content_type() { Some(c) if c.is_form() => Self::from_form(req, data).await, Some(c) if c.is_form_data() => Self::from_multipart(req, data).await, - _ => return Outcome::Forward((data, Status::UnsupportedMediaType)), + _ => return Outcome::Forward((data, Error { + name: None, value: None, kind: ErrorKind::UnsupportedMediaType, entity: Entity::Form, + }.into())), }; match parser { Ok(storage) => Outcome::Success(storage), - Err(e) => Outcome::Error((e.status(), e.into())) + Err(e) => Outcome::Error(e.into()), } } diff --git a/core/lib/src/fs/temp_file.rs b/core/lib/src/fs/temp_file.rs index 0464dd46ce..32e0cda90d 100644 --- a/core/lib/src/fs/temp_file.rs +++ b/core/lib/src/fs/temp_file.rs @@ -1,11 +1,11 @@ use std::{io, mem}; use std::path::{PathBuf, Path}; +use crate::response::status::BadRequest; use crate::Request; -use crate::http::{ContentType, Status}; +use crate::http::ContentType; use crate::data::{self, FromData, Data, Capped, N, Limits}; use crate::form::{FromFormField, ValueField, DataField, error::Errors}; -use crate::outcome::IntoOutcome; use crate::fs::FileName; use tokio::task; @@ -551,7 +551,7 @@ impl<'v> FromFormField<'v> for Capped> { #[crate::async_trait] impl<'r> FromData<'r> for Capped> { - type Error = io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> { let has_form = |ty: &ContentType| ty.is_form_data() || ty.is_form(); @@ -562,9 +562,10 @@ impl<'r> FromData<'r> for Capped> { Perhaps you meant to use `Form>` instead?"); } - TempFile::from(req, data, None, req.content_type().cloned()) - .await - .or_error(Status::BadRequest) + match TempFile::from(req, data, None, req.content_type().cloned()).await { + Ok(f) => data::Outcome::Success(f), + Err(e) => data::Outcome::Error(BadRequest(e)), + } } } diff --git a/core/lib/src/mtls/certificate.rs b/core/lib/src/mtls/certificate.rs index fe887d434d..b8645ea7df 100644 --- a/core/lib/src/mtls/certificate.rs +++ b/core/lib/src/mtls/certificate.rs @@ -2,7 +2,6 @@ use ref_cast::RefCast; use crate::mtls::{x509, oid, bigint, Name, Result, Error}; use crate::request::{Request, FromRequest, Outcome}; -use crate::http::Status; /// A request guard for validated, verified client certificates. /// @@ -117,7 +116,7 @@ impl<'r> FromRequest<'r> for Certificate<'r> { let certs: Outcome<_, Error> = req.connection .peer_certs .as_ref() - .or_forward(Status::Unauthorized); + .or_forward(Error::Empty); let chain = try_outcome!(certs); Certificate::parse(chain.inner()).or_error(()) diff --git a/core/lib/src/outcome.rs b/core/lib/src/outcome.rs index 6db291e151..e72aee71aa 100644 --- a/core/lib/src/outcome.rs +++ b/core/lib/src/outcome.rs @@ -88,8 +88,7 @@ use crate::catcher::TypedError; use crate::{route, request, response}; -use crate::data::{self, Data, FromData}; -use crate::http::Status; +use crate::data::Data; use self::Outcome::*; @@ -747,30 +746,9 @@ impl IntoOutcome> for Option { } } -impl<'r, T: FromData<'r>> IntoOutcome> for Result { - type Error = Status; - type Forward = (Data<'r>, Status); - - #[inline] - fn or_error(self, error: Status) -> data::Outcome<'r, T> { - match self { - Ok(val) => Success(val), - Err(err) => Error((error, err)) - } - } - - #[inline] - fn or_forward(self, (data, forward): (Data<'r>, Status)) -> data::Outcome<'r, T> { - match self { - Ok(val) => Success(val), - Err(_) => Forward((data, forward)) - } - } -} - impl IntoOutcome> for Result { type Error = (); - type Forward = Status; + type Forward = E; #[inline] fn or_error(self, _: ()) -> request::Outcome { @@ -781,7 +759,7 @@ impl IntoOutcome> for Result { } #[inline] - fn or_forward(self, status: Status) -> request::Outcome { + fn or_forward(self, status: E) -> request::Outcome { match self { Ok(val) => Success(val), Err(_) => Forward(status) diff --git a/core/lib/src/request/from_param.rs b/core/lib/src/request/from_param.rs index f639d38efb..733e38dd85 100644 --- a/core/lib/src/request/from_param.rs +++ b/core/lib/src/request/from_param.rs @@ -1,9 +1,13 @@ +use std::fmt; use std::str::FromStr; use std::path::PathBuf; +use transient::{CanTranscendTo, Inv, Transient}; + +use crate::catcher::TypedError; use crate::error::Empty; use crate::either::Either; -use crate::http::uri::{Segments, error::PathError, fmt::Path}; +use crate::http::{uri::{Segments, error::PathError, fmt::Path}, Status}; /// Trait to convert a dynamic path segment string to a concrete value. /// @@ -12,11 +16,6 @@ use crate::http::uri::{Segments, error::PathError, fmt::Path}; /// a dynamic segment `` where `param` has some type `T` that implements /// `FromParam`, `T::from_param` will be called. /// -/// # Deriving -/// -/// The `FromParam` trait can be automatically derived for C-like enums. See -/// [`FromParam` derive](macro@rocket::FromParam) for more information. -/// /// # Forwarding /// /// If the conversion fails, the incoming request will be forwarded to the next @@ -217,6 +216,67 @@ pub trait FromParam<'a>: Sized { fn from_param(param: &'a str) -> Result; } +/// The error type produced by every `FromParam` implementation. +/// This can be used to obtain both the error returned by the param, +/// as well as the raw parameter value. +pub struct FromParamError<'a, T> { + pub raw: &'a str, + pub error: T, + _priv: (), +} + +impl<'a, T> FromParamError<'a, T> { + /// Unstable constructor used by codegen + #[doc(hidden)] + pub fn new(raw: &'a str, error: T) -> Self { + Self { + raw, + error, + _priv: (), + } + } +} + +impl<'a, T: TypedError<'a>> TypedError<'a> for FromParamError<'a, T> + where Self: Transient> +{ + fn respond_to(&self, request: &'a crate::Request<'_>) -> Result, Status> { + self.error.respond_to(request) + } + + fn source(&'a self) -> Option<&'a (dyn TypedError<'a> + 'a)> { + Some(&self.error) + } + + fn status(&self) -> Status { + Status::UnprocessableEntity + } +} + +// SAFETY: Since `T` (and &'a str) `CanTransendTo` `Inv<'a>`, it is safe to +// transend `FromParamError<'a, T>` to `Inv<'a>` +unsafe impl<'a, T: Transient + 'a> Transient for FromParamError<'a, T> + where T::Transience: CanTranscendTo>, +{ + type Static = FromParamError<'static, T::Static>; + type Transience = Inv<'a>; +} + +impl<'a, T: fmt::Debug> fmt::Debug for FromParamError<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FromParamError") + .field("raw", &self.raw) + .field("error", &self.error) + .finish_non_exhaustive() + } +} + +impl<'a, T: fmt::Display> fmt::Display for FromParamError<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error.fmt(f) + } +} + impl<'a> FromParam<'a> for &'a str { type Error = Empty; @@ -338,6 +398,68 @@ pub trait FromSegments<'r>: Sized { fn from_segments(segments: Segments<'r, Path>) -> Result; } +/// The error type produced by every `FromParam` implementation. +/// This can be used to obtain both the error returned by the param, +/// as well as the raw parameter value. +pub struct FromSegmentsError<'a, T> { + pub raw: Segments<'a, Path>, + pub error: T, + _priv: (), +} + +impl<'a, T> FromSegmentsError<'a, T> { + /// Unstable constructor used by codegen + #[doc(hidden)] + pub fn new(raw: Segments<'a, Path>, error: T) -> Self { + Self { + raw, + error, + _priv: (), + } + } +} + + +impl<'a, T: TypedError<'a>> TypedError<'a> for FromSegmentsError<'a, T> + where Self: Transient> +{ + fn respond_to(&self, request: &'a crate::Request<'_>) -> Result, Status> { + self.error.respond_to(request) + } + + fn source(&'a self) -> Option<&'a (dyn TypedError<'a> + 'a)> { + Some(&self.error) + } + + fn status(&self) -> Status { + Status::UnprocessableEntity + } +} + +// SAFETY: Since `T` (and Segments<'a, Path>) `CanTransendTo` `Inv<'a>`, it is safe to +// transend `FromSegmentsError<'a, T>` to `Inv<'a>` +unsafe impl<'a, T: Transient + 'a> Transient for FromSegmentsError<'a, T> + where T::Transience: CanTranscendTo>, +{ + type Static = FromSegmentsError<'static, T::Static>; + type Transience = Inv<'a>; +} + +impl<'a, T: fmt::Debug> fmt::Debug for FromSegmentsError<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("FromSegmentsError") + .field("raw", &self.raw) + .field("error", &self.error) + .finish_non_exhaustive() + } +} + +impl<'a, T: fmt::Display> fmt::Display for FromSegmentsError<'a, T> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.error.fmt(f) + } +} + impl<'r> FromSegments<'r> for Segments<'r, Path> { type Error = std::convert::Infallible; diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index 432cedee68..b97b0365ed 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -6,11 +6,11 @@ use crate::{Request, Route}; use crate::outcome::{self, IntoOutcome, Outcome::*}; use crate::http::uri::{Host, Origin}; -use crate::http::{Status, ContentType, Accept, Method, ProxyProto, CookieJar}; +use crate::http::{ContentType, Accept, Method, ProxyProto, CookieJar}; use crate::listener::Endpoint; /// Type alias for the `Outcome` of a `FromRequest` conversion. -pub type Outcome = outcome::Outcome; +pub type Outcome = outcome::Outcome; /// Trait implemented by request guards to derive a value from incoming /// requests. @@ -411,24 +411,24 @@ impl<'r> FromRequest<'r> for &'r Origin<'r> { #[crate::async_trait] impl<'r> FromRequest<'r> for &'r Host<'r> { - type Error = Infallible; + type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { match request.host() { Some(host) => Success(host), - None => Forward(Status::InternalServerError) + None => Forward(()) } } } #[crate::async_trait] impl<'r> FromRequest<'r> for &'r Route { - type Error = Infallible; + type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { match request.route() { Some(route) => Success(route), - None => Forward(Status::InternalServerError) + None => Forward(()) } } } @@ -444,78 +444,78 @@ impl<'r> FromRequest<'r> for &'r CookieJar<'r> { #[crate::async_trait] impl<'r> FromRequest<'r> for &'r Accept { - type Error = Infallible; + type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { match request.accept() { Some(accept) => Success(accept), - None => Forward(Status::InternalServerError) + None => Forward(()) } } } #[crate::async_trait] impl<'r> FromRequest<'r> for &'r ContentType { - type Error = Infallible; + type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { match request.content_type() { Some(content_type) => Success(content_type), - None => Forward(Status::InternalServerError) + None => Forward(()) } } } #[crate::async_trait] impl<'r> FromRequest<'r> for IpAddr { - type Error = Infallible; + type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { match request.client_ip() { Some(addr) => Success(addr), - None => Forward(Status::InternalServerError) + None => Forward(()) } } } #[crate::async_trait] impl<'r> FromRequest<'r> for ProxyProto<'r> { - type Error = std::convert::Infallible; + type Error = (); async fn from_request(request: &'r Request<'_>) -> Outcome { - request.proxy_proto().or_forward(Status::InternalServerError) + request.proxy_proto().or_forward(()) } } #[crate::async_trait] impl<'r> FromRequest<'r> for &'r Endpoint { - type Error = Infallible; + type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { - request.remote().or_forward(Status::InternalServerError) + async fn from_request(request: &'r Request<'_>) -> Outcome { + request.remote().or_forward(()) } } #[crate::async_trait] impl<'r> FromRequest<'r> for SocketAddr { - type Error = Infallible; + type Error = (); - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { request.remote() .and_then(|r| r.socket_addr()) - .or_forward(Status::InternalServerError) + .or_forward(()) } } #[crate::async_trait] impl<'r, T: FromRequest<'r>> FromRequest<'r> for Result { - type Error = Infallible; + type Error = T::Error; - async fn from_request(request: &'r Request<'_>) -> Outcome { + async fn from_request(request: &'r Request<'_>) -> Outcome { match T::from_request(request).await { Success(val) => Success(Ok(val)), Error(e) => Success(Err(e)), - Forward(status) => Forward(status), + Forward(forward) => Forward(forward), } } } diff --git a/core/lib/src/request/mod.rs b/core/lib/src/request/mod.rs index 48ac79c7bd..fcc3531fe4 100644 --- a/core/lib/src/request/mod.rs +++ b/core/lib/src/request/mod.rs @@ -10,7 +10,7 @@ mod tests; pub use self::request::Request; pub use self::from_request::{FromRequest, Outcome}; -pub use self::from_param::{FromParam, FromSegments}; +pub use self::from_param::{FromParam, FromParamError, FromSegments, FromSegmentsError}; #[doc(hidden)] pub use rocket_codegen::FromParam; diff --git a/core/lib/src/response/debug.rs b/core/lib/src/response/debug.rs index e3295bb5a6..c5c7f2f51b 100644 --- a/core/lib/src/response/debug.rs +++ b/core/lib/src/response/debug.rs @@ -1,7 +1,12 @@ +use transient::Static; + +use crate::catcher::TypedError; use crate::request::Request; use crate::response::{self, Responder}; use crate::http::Status; +use super::Response; + /// Debug prints the internal value before forwarding to the 500 error catcher. /// /// This value exists primarily to allow handler return types that would not @@ -82,6 +87,17 @@ impl<'r, E: std::fmt::Debug> Responder<'r, 'static> for Debug { } } +// TODO: Typed: This is a stop-gap measure to allow any 'static type to be a `TypedError` +impl<'r, E: std::fmt::Debug + Send + Sync + 'static> TypedError<'r> for Debug { + fn respond_to(&self, _: &'r Request<'_>) -> Result, Status> { + let type_name = std::any::type_name::(); + info!(type_name, value = ?self.0, "debug response (500)"); + Err(Status::InternalServerError) + } +} + +impl Static for Debug { } + /// Prints a warning with the error and forwards to the `500` error catcher. impl<'r> Responder<'r, 'static> for std::io::Error { fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { diff --git a/core/lib/src/response/mod.rs b/core/lib/src/response/mod.rs index dfd4257c07..c2bb6505e8 100644 --- a/core/lib/src/response/mod.rs +++ b/core/lib/src/response/mod.rs @@ -38,4 +38,4 @@ pub use self::flash::Flash; pub use self::debug::Debug; /// Type alias for the `Result` of a [`Responder::respond_to()`] call. -pub type Result<'r, 'o> = std::result::Result, Box>>; +pub type Result<'r, 'o> = std::result::Result, Box + 'r>>; diff --git a/core/lib/src/response/responder.rs b/core/lib/src/response/responder.rs index 54d878c814..6079aa8349 100644 --- a/core/lib/src/response/responder.rs +++ b/core/lib/src/response/responder.rs @@ -2,6 +2,7 @@ use std::fs::File; use std::io::Cursor; use std::sync::Arc; +use crate::catcher::TypedError; use crate::http::{Status, ContentType, StatusClass}; use crate::response::{self, Response}; use crate::request::Request; @@ -494,13 +495,14 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Option { /// Responds with the wrapped `Responder` in `self`, whether it is `Ok` or /// `Err`. -impl<'r, 'o: 'r, 't: 'o, 'e: 'o, T, E> Responder<'r, 'o> for Result - where T: Responder<'r, 't>, E: Responder<'r, 'e> +impl<'r, 'o: 'r, T, E> Responder<'r, 'o> for Result + where T: Responder<'r, 'o>, E: TypedError<'r> + 'r { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'o> { match self { Ok(responder) => responder.respond_to(req), - Err(responder) => responder.respond_to(req), + // Err(responder) => responder.respond_to(req), + Err(e) => Err(Box::new(e)), } } } diff --git a/core/lib/src/response/status.rs b/core/lib/src/response/status.rs index 0095c48144..35f4e6544a 100644 --- a/core/lib/src/response/status.rs +++ b/core/lib/src/response/status.rs @@ -29,8 +29,9 @@ use std::hash::{Hash, Hasher}; use std::collections::hash_map::DefaultHasher; use std::borrow::Cow; -use transient::Transient; +use transient::{CanTranscendTo, Inv, Transient}; +use crate::catcher::TypedError; use crate::request::Request; use crate::response::{self, Responder, Response}; use crate::http::Status; @@ -300,6 +301,32 @@ macro_rules! status_response { Custom(Status::$T, self.0).respond_to(req) } } + + impl From for $T { + fn from(v: R) -> Self { + Self(v) + } + } + + unsafe impl Transient for $T { + type Static = BadRequest; + type Transience = R::Transience; + } + + impl<'r, R: TypedError<'r>> TypedError<'r> for $T + where R: Transient, + R::Transience: CanTranscendTo>, + { + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + self.0.respond_to(request) + } + + fn name(&self) -> &'static str { self.0.name() } + + fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { Some(&self.0) } + + fn status(&self) -> Status { Status::$T } + } } } diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index e4d02b1693..35add00042 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -209,9 +209,9 @@ impl<'r, 'o: 'r> Outcome<'o> { /// ``` #[inline] pub fn try_from(req: &'r Request<'_>, result: Result) -> Outcome<'r> - where R: Responder<'r, 'o>, E: std::fmt::Debug + where R: Responder<'r, 'o>, E: TypedError<'r> + 'r { - let responder = result.map_err(crate::response::Debug); + let responder = result;//.map_err(crate::response::Debug); match responder.respond_to(req) { Ok(response) => Outcome::Success(response), Err(status) => Outcome::Error(status) diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 2770783a4d..318e1a3615 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -589,7 +589,7 @@ mod test { fn catcher<'a>(r: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { let client = Client::debug_with(vec![]).expect("client"); let request = client.get(Origin::parse(uri).unwrap()); - r.catch(status, &request) + r.catch(status, None, &request) } macro_rules! assert_catcher_routing { diff --git a/core/lib/src/serde/json.rs b/core/lib/src/serde/json.rs index fd08f2401e..8a493e3e44 100644 --- a/core/lib/src/serde/json.rs +++ b/core/lib/src/serde/json.rs @@ -27,6 +27,7 @@ use std::{io, fmt, error}; use std::ops::{Deref, DerefMut}; +use crate::catcher::TypedError; use crate::request::{Request, local_cache}; use crate::data::{Limits, Data, FromData, Outcome}; use crate::response::{self, Responder, content}; @@ -137,12 +138,32 @@ pub enum Error<'a> { /// received from the user, while the `Error` in `.1` is the deserialization /// error from `serde`. Parse(&'a str, serde_json::error::Error), + + /// An I/O error occurred while reading the incoming request data. + TooLarge(io::Error), +} + +// SAFETY: It's always safe to assume `Inv<'a>` +unsafe impl<'a> Transient for Error<'a> { + type Static = Error<'static>; + type Transience = Inv<'a>; +} + +impl<'r> TypedError<'r> for Error<'r> { + fn status(&self) -> Status { + match self { + Self::TooLarge(..) => Status::PayloadTooLarge, + Self::Io(..) => Status::BadRequest, + Self::Parse(..) => Status::UnprocessableEntity, + } + } } impl<'a> fmt::Display for Error<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Io(err) => write!(f, "i/o error: {}", err), + Self::TooLarge(err) => write!(f, "i/o error: {}", err), Self::Parse(_, err) => write!(f, "parse error: {}", err), } } @@ -152,6 +173,7 @@ impl<'a> error::Error for Error<'a> { fn source(&self) -> Option<&(dyn error::Error + 'static)> { match self { Self::Io(err) => Some(err), + Self::TooLarge(err) => Some(err), Self::Parse(_, err) => Some(err), } } @@ -201,12 +223,12 @@ impl<'r, T: Deserialize<'r>> FromData<'r> for Json { match Self::from_data(req, data).await { Ok(value) => Outcome::Success(value), Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => { - Outcome::Error((Status::PayloadTooLarge, Error::Io(e))) + Outcome::Error(Error::TooLarge(e)) }, Err(Error::Parse(s, e)) if e.classify() == serde_json::error::Category::Data => { - Outcome::Error((Status::UnprocessableEntity, Error::Parse(s, e))) + Outcome::Error(Error::Parse(s, e)) }, - Err(e) => Outcome::Error((Status::BadRequest, e)), + Err(e) => Outcome::Error(e), } } @@ -279,6 +301,7 @@ impl From> for form::Error<'_> { fn from(e: Error<'_>) -> Self { match e { Error::Io(e) => e.into(), + Error::TooLarge(e) => e.into(), Error::Parse(_, e) => form::Error::custom(e) } } @@ -416,6 +439,7 @@ crate::export! { /// ``` #[doc(inline)] pub use serde_json::Value; +use transient::{Inv, Transient}; /// Deserialize an instance of type `T` from bytes of JSON text. /// diff --git a/core/lib/src/serde/msgpack.rs b/core/lib/src/serde/msgpack.rs index 93fcd4511d..d8224998d1 100644 --- a/core/lib/src/serde/msgpack.rs +++ b/core/lib/src/serde/msgpack.rs @@ -198,16 +198,7 @@ impl<'r, T: Deserialize<'r>> FromData<'r> for MsgPack { async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> Outcome<'r, Self> { match Self::from_data(req, data).await { Ok(value) => Outcome::Success(value), - Err(Error::InvalidDataRead(e)) if e.kind() == io::ErrorKind::UnexpectedEof => { - Outcome::Error((Status::PayloadTooLarge, Error::InvalidDataRead(e))) - }, - | Err(e@Error::TypeMismatch(_)) - | Err(e@Error::OutOfRange) - | Err(e@Error::LengthMismatch(_)) - => { - Outcome::Error((Status::UnprocessableEntity, e)) - }, - Err(e) => Outcome::Error((Status::BadRequest, e)), + Err(e) => Outcome::Error(e), } } } diff --git a/core/lib/tests/forward-includes-status-1560.rs b/core/lib/tests/forward-includes-status-1560.rs index f5b8190d47..1e612ff72a 100644 --- a/core/lib/tests/forward-includes-status-1560.rs +++ b/core/lib/tests/forward-includes-status-1560.rs @@ -7,7 +7,7 @@ struct Authenticated; #[rocket::async_trait] impl<'r> FromRequest<'r> for Authenticated { - type Error = std::convert::Infallible; + type Error = Status; async fn from_request(request: &'r Request<'_>) -> request::Outcome { if request.headers().contains("Authenticated") { @@ -22,7 +22,7 @@ struct TeapotForward; #[rocket::async_trait] impl<'r> FromRequest<'r> for TeapotForward { - type Error = std::convert::Infallible; + type Error = Status; async fn from_request(_: &'r Request<'_>) -> request::Outcome { request::Outcome::Forward(Status::ImATeapot) diff --git a/core/lib/tests/local-request-content-type-issue-505.rs b/core/lib/tests/local-request-content-type-issue-505.rs index d5042803b0..e997540977 100644 --- a/core/lib/tests/local-request-content-type-issue-505.rs +++ b/core/lib/tests/local-request-content-type-issue-505.rs @@ -9,9 +9,9 @@ struct HasContentType; #[rocket::async_trait] impl<'r> FromRequest<'r> for HasContentType { - type Error = (); + type Error = Status; - async fn from_request(req: &'r Request<'_>) -> request::Outcome { + async fn from_request(req: &'r Request<'_>) -> request::Outcome { req.content_type().map(|_| HasContentType).or_forward(Status::NotFound) } } @@ -20,7 +20,7 @@ use rocket::data::{self, FromData}; #[rocket::async_trait] impl<'r> FromData<'r> for HasContentType { - type Error = (); + type Error = Status; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> { req.content_type().map(|_| HasContentType).or_forward((data, Status::NotFound)) diff --git a/core/lib/tests/responder_lifetime-issue-345.rs b/core/lib/tests/responder_lifetime-issue-345.rs index 4cd12f000b..372ffb0ab6 100644 --- a/core/lib/tests/responder_lifetime-issue-345.rs +++ b/core/lib/tests/responder_lifetime-issue-345.rs @@ -13,7 +13,7 @@ pub struct CustomResponder<'r, R> { } impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for CustomResponder<'r, R> { - fn respond_to(self, req: &'r Request<'_>) -> Result<'o> { + fn respond_to(self, req: &'r Request<'_>) -> Result<'r, 'o> { self.responder.respond_to(req) } } diff --git a/examples/cookies/src/session.rs b/examples/cookies/src/session.rs index 31d0fc613c..4201f9c5a3 100644 --- a/examples/cookies/src/session.rs +++ b/examples/cookies/src/session.rs @@ -3,6 +3,7 @@ use rocket::request::{self, FlashMessage, FromRequest, Request}; use rocket::response::{Redirect, Flash}; use rocket::http::{CookieJar, Status}; use rocket::form::Form; +use rocket::either::Either; use rocket_dyn_templates::{Template, context}; @@ -17,7 +18,7 @@ struct User(usize); #[rocket::async_trait] impl<'r> FromRequest<'r> for User { - type Error = std::convert::Infallible; + type Error = Status; async fn from_request(request: &'r Request<'_>) -> request::Outcome { request.cookies() @@ -58,12 +59,12 @@ fn login_page(flash: Option>) -> Template { } #[post("/login", data = "")] -fn post_login(jar: &CookieJar<'_>, login: Form>) -> Result> { +fn post_login(jar: &CookieJar<'_>, login: Form>) -> Either> { if login.username == "Sergio" && login.password == "password" { jar.add_private(("user_id", "1")); - Ok(Redirect::to(uri!(index))) + Either::Left(Redirect::to(uri!(index))) } else { - Err(Flash::error(Redirect::to(uri!(login_page)), "Invalid username/password.")) + Either::Right(Flash::error(Redirect::to(uri!(login_page)), "Invalid username/password.")) } } diff --git a/examples/error-handling/Cargo.toml b/examples/error-handling/Cargo.toml index c19138a7b2..a8f17a5bbf 100644 --- a/examples/error-handling/Cargo.toml +++ b/examples/error-handling/Cargo.toml @@ -6,4 +6,5 @@ edition = "2021" publish = false [dependencies] -rocket = { path = "../../core/lib" } +rocket = { path = "../../core/lib", features = ["json"] } +transient = { path = "/code/matthew/transient" } diff --git a/examples/error-handling/src/main.rs b/examples/error-handling/src/main.rs index ffa0a6b13f..c2aa473964 100644 --- a/examples/error-handling/src/main.rs +++ b/examples/error-handling/src/main.rs @@ -2,9 +2,14 @@ #[cfg(test)] mod tests; -use rocket::{Rocket, Request, Build}; +use std::num::ParseIntError; + +use rocket::{Rocket, Build, Responder}; use rocket::response::{content, status}; -use rocket::http::Status; +use rocket::http::{Status, uri::Origin}; + +use rocket::serde::{Serialize, json::Json}; +use rocket::request::FromParamError; #[get("/hello//")] fn hello(name: &str, age: i8) -> String { @@ -16,6 +21,20 @@ fn forced_error(code: u16) -> Status { Status::new(code) } +// TODO: Derive TypedError +#[derive(TypedError, Debug)] +struct CustomError; + +#[get("/")] +fn forced_custom_error() -> Result<(), CustomError> { + Err(CustomError) +} + +#[catch(500, error = "<_e>")] +fn catch_custom(_e: &CustomError) -> &'static str { + "You found the custom error!" +} + #[catch(404)] fn general_not_found() -> content::RawHtml<&'static str> { content::RawHtml(r#" @@ -25,11 +44,36 @@ fn general_not_found() -> content::RawHtml<&'static str> { } #[catch(404)] -fn hello_not_found(req: &Request<'_>) -> content::RawHtml { +fn hello_not_found(uri: &Origin<'_>) -> content::RawHtml { content::RawHtml(format!("\

Sorry, but '{}' is not a valid path!

\

Try visiting /hello/<name>/<age> instead.

", - req.uri())) + uri)) +} + +// Code to generate a Json response: +#[derive(Responder)] +#[response(status = 422)] +struct ParameterError(T); + +#[derive(Serialize)] +#[serde(crate = "rocket::serde")] +struct ErrorInfo<'a> { + invalid_value: &'a str, + description: String, +} + +// Actual catcher: +#[catch(422, error = "")] +fn param_error<'a>( + // `&ParseIntError` would also work here, but `&FromParamError` + // also gives us access to `raw`, the specific segment that failed to parse. + int_error: &FromParamError<'a, ParseIntError> +) -> ParameterError>> { + ParameterError(Json(ErrorInfo { + invalid_value: int_error.raw, + description: format!("{}", int_error.error), + })) } #[catch(default)] @@ -38,8 +82,8 @@ fn sergio_error() -> &'static str { } #[catch(default)] -fn default_catcher(status: Status, req: &Request<'_>) -> status::Custom { - let msg = format!("{} ({})", status, req.uri()); +fn default_catcher(status: Status, uri: &Origin<'_>) -> status::Custom { + let msg = format!("{} ({})", status, uri); status::Custom(status, msg) } @@ -51,9 +95,9 @@ fn rocket() -> Rocket { rocket::build() // .mount("/", routes![hello, hello]) // uncomment this to get an error // .mount("/", routes![unmanaged]) // uncomment this to get a sentinel error - .mount("/", routes![hello, forced_error]) - .register("/", catchers![general_not_found, default_catcher]) - .register("/hello", catchers![hello_not_found]) + .mount("/", routes![hello, forced_error, forced_custom_error]) + .register("/", catchers![general_not_found, default_catcher, catch_custom]) + .register("/hello", catchers![hello_not_found, param_error]) .register("/hello/Sergio", catchers![sergio_error]) } diff --git a/examples/error-handling/src/tests.rs b/examples/error-handling/src/tests.rs index fcd78424c9..735fd9f566 100644 --- a/examples/error-handling/src/tests.rs +++ b/examples/error-handling/src/tests.rs @@ -1,5 +1,6 @@ use rocket::local::blocking::Client; use rocket::http::Status; +use rocket::serde::json::to_string as json_string; #[test] fn test_hello() { @@ -24,19 +25,19 @@ fn forced_error() { assert_eq!(response.into_string().unwrap(), expected.0); let request = client.get("/405"); - let expected = super::default_catcher(Status::MethodNotAllowed, request.inner()); + let expected = super::default_catcher(Status::MethodNotAllowed, request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::MethodNotAllowed); assert_eq!(response.into_string().unwrap(), expected.1); let request = client.get("/533"); - let expected = super::default_catcher(Status::new(533), request.inner()); + let expected = super::default_catcher(Status::new(533), request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::new(533)); assert_eq!(response.into_string().unwrap(), expected.1); let request = client.get("/700"); - let expected = super::default_catcher(Status::InternalServerError, request.inner()); + let expected = super::default_catcher(Status::InternalServerError, request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::InternalServerError); assert_eq!(response.into_string().unwrap(), expected.1); @@ -48,16 +49,22 @@ fn test_hello_invalid_age() { for path in &["Ford/-129", "Trillian/128"] { let request = client.get(format!("/hello/{}", path)); - let expected = super::default_catcher(Status::UnprocessableEntity, request.inner()); + let expected = super::ErrorInfo { + invalid_value: path.split_once("/").unwrap().1, + description: format!( + "{}", + path.split_once("/").unwrap().1.parse::().unwrap_err() + ), + }; let response = request.dispatch(); assert_eq!(response.status(), Status::UnprocessableEntity); - assert_eq!(response.into_string().unwrap(), expected.1); + assert_eq!(response.into_string().unwrap(), json_string(&expected).unwrap()); } { let path = &"foo/bar/baz"; let request = client.get(format!("/hello/{}", path)); - let expected = super::hello_not_found(request.inner()); + let expected = super::hello_not_found(request.uri()); let response = request.dispatch(); assert_eq!(response.status(), Status::NotFound); assert_eq!(response.into_string().unwrap(), expected.0); @@ -68,6 +75,8 @@ fn test_hello_invalid_age() { fn test_hello_sergio() { let client = Client::tracked(super::rocket()).unwrap(); + // TODO: typed: This logic has changed, either needs to be fixed + // or this test changed. for path in &["oops", "-129"] { let request = client.get(format!("/hello/Sergio/{}", path)); let expected = super::sergio_error(); diff --git a/examples/manual-routing/src/main.rs b/examples/manual-routing/src/main.rs index e4a21620f0..5765fb8ef4 100644 --- a/examples/manual-routing/src/main.rs +++ b/examples/manual-routing/src/main.rs @@ -7,6 +7,7 @@ use rocket::http::{Status, Method::{Get, Post}}; use rocket::response::{Responder, status::Custom}; use rocket::outcome::{try_outcome, IntoOutcome}; use rocket::tokio::fs::File; +use rocket::catcher::TypedError; fn forward<'r>(_req: &'r Request, data: Data<'r>) -> route::BoxFuture<'r> { Box::pin(async move { route::Outcome::forward(data, Status::NotFound) }) @@ -62,9 +63,9 @@ fn get_upload<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { route::Outcome::from(req, std::fs::File::open(path).ok()).pin() } -fn not_found_handler<'r>(_: Status, req: &'r Request) -> catcher::BoxFuture<'r> { +fn not_found_handler<'r>(_: Status, _: &'r dyn TypedError<'r>, req: &'r Request) -> catcher::BoxFuture<'r> { let responder = Custom(Status::NotFound, format!("Couldn't find: {}", req.uri())); - Box::pin(async move { responder.respond_to(req) }) + Box::pin(async move { responder.respond_to(req).map_err(|e| e.status()) }) } #[derive(Clone)] @@ -82,11 +83,11 @@ impl CustomHandler { impl route::Handler for CustomHandler { async fn handle<'r>(&self, req: &'r Request<'_>, data: Data<'r>) -> route::Outcome<'r> { let self_data = self.data; - let id = req.param::<&str>(0) - .and_then(Result::ok) - .or_forward((data, Status::NotFound)); - - route::Outcome::from(req, format!("{} - {}", self_data, try_outcome!(id))) + match req.param::<&str>(0) { + Some(Ok(id)) => route::Outcome::from(req, format!("{} - {}", self_data, id)), + Some(Err(e)) => route::Outcome::Error(Box::new(e)), + None => route::Outcome::Forward((data, Box::new(Status::NotFound))), + } } } diff --git a/examples/pastebin/src/paste_id.rs b/examples/pastebin/src/paste_id.rs index 3f31a67b42..284d34cd85 100644 --- a/examples/pastebin/src/paste_id.rs +++ b/examples/pastebin/src/paste_id.rs @@ -32,14 +32,18 @@ impl PasteId<'_> { } } +#[derive(Debug, TypedError)] +#[error(debug)] +pub struct InvalidId<'a>(pub &'a str); + /// Returns an instance of `PasteId` if the path segment is a valid ID. /// Otherwise returns the invalid ID as the `Err` value. impl<'a> FromParam<'a> for PasteId<'a> { - type Error = &'a str; + type Error = InvalidId<'a>; fn from_param(param: &'a str) -> Result { param.chars().all(|c| c.is_ascii_alphanumeric()) .then(|| PasteId(param.into())) - .ok_or(param) + .ok_or(InvalidId(param)) } } diff --git a/examples/responders/src/main.rs b/examples/responders/src/main.rs index 90b65b3be2..9f55c8f220 100644 --- a/examples/responders/src/main.rs +++ b/examples/responders/src/main.rs @@ -113,10 +113,10 @@ fn redir_login() -> &'static str { } #[get("/redir/")] -fn maybe_redir(name: &str) -> Result<&'static str, Redirect> { +fn maybe_redir(name: &str) -> Either<&'static str, Redirect> { match name { - "Sergio" => Ok("Hello, Sergio!"), - _ => Err(Redirect::to(uri!(redir_login))), + "Sergio" => Either::Left("Hello, Sergio!"), + _ => Either::Right(Redirect::to(uri!(redir_login))), } } diff --git a/examples/serialization/src/tests.rs b/examples/serialization/src/tests.rs index 8a46c13dcc..bff72896f1 100644 --- a/examples/serialization/src/tests.rs +++ b/examples/serialization/src/tests.rs @@ -38,7 +38,9 @@ fn json_bad_get_put() { // Try to put a message without a proper body. let res = client.put("/json/80").header(ContentType::JSON).dispatch(); - assert_eq!(res.status(), Status::BadRequest); + // TODO: Typed: This behavior has changed + assert_eq!(res.status(), Status::UnprocessableEntity); + // Status::BadRequest); // Try to put a message with a semantically invalid body. let res = client.put("/json/0") diff --git a/examples/serialization/src/uuid.rs b/examples/serialization/src/uuid.rs index 15c804b733..41a902f882 100644 --- a/examples/serialization/src/uuid.rs +++ b/examples/serialization/src/uuid.rs @@ -8,10 +8,10 @@ use rocket::serde::uuid::Uuid; struct People(HashMap); #[get("/people/")] -fn people(id: Uuid, people: &State) -> Result { +fn people(id: Uuid, people: &State) -> String { people.0.get(&id) .map(|person| format!("We found: {}", person)) - .ok_or_else(|| format!("Missing person for UUID: {}", id)) + .unwrap_or_else(|| format!("Missing person for UUID: {}", id)) } pub fn stage() -> rocket::fairing::AdHoc { diff --git a/examples/state/src/request_local.rs b/examples/state/src/request_local.rs index 986242b757..c796d3dbd8 100644 --- a/examples/state/src/request_local.rs +++ b/examples/state/src/request_local.rs @@ -1,6 +1,6 @@ use std::sync::atomic::{AtomicUsize, Ordering}; -use rocket::State; +use rocket::{State, StateMissing}; use rocket::outcome::{Outcome, try_outcome}; use rocket::request::{self, FromRequest, Request}; use rocket::fairing::AdHoc; @@ -18,9 +18,9 @@ struct Guard4; #[rocket::async_trait] impl<'r> FromRequest<'r> for Guard1 { - type Error = (); + type Error = StateMissing; - async fn from_request(req: &'r Request<'_>) -> request::Outcome { + async fn from_request(req: &'r Request<'_>) -> request::Outcome { let atomics = try_outcome!(req.guard::<&State>().await); atomics.uncached.fetch_add(1, Ordering::Relaxed); req.local_cache(|| { @@ -33,9 +33,9 @@ impl<'r> FromRequest<'r> for Guard1 { #[rocket::async_trait] impl<'r> FromRequest<'r> for Guard2 { - type Error = (); + type Error = StateMissing; - async fn from_request(req: &'r Request<'_>) -> request::Outcome { + async fn from_request(req: &'r Request<'_>) -> request::Outcome { try_outcome!(req.guard::().await); Outcome::Success(Guard2) } @@ -43,9 +43,9 @@ impl<'r> FromRequest<'r> for Guard2 { #[rocket::async_trait] impl<'r> FromRequest<'r> for Guard3 { - type Error = (); + type Error = StateMissing; - async fn from_request(req: &'r Request<'_>) -> request::Outcome { + async fn from_request(req: &'r Request<'_>) -> request::Outcome { let atomics = try_outcome!(req.guard::<&State>().await); atomics.uncached.fetch_add(1, Ordering::Relaxed); req.local_cache_async(async { @@ -58,9 +58,9 @@ impl<'r> FromRequest<'r> for Guard3 { #[rocket::async_trait] impl<'r> FromRequest<'r> for Guard4 { - type Error = (); + type Error = StateMissing; - async fn from_request(req: &'r Request<'_>) -> request::Outcome { + async fn from_request(req: &'r Request<'_>) -> request::Outcome { try_outcome!(Guard3::from_request(req).await); Outcome::Success(Guard4) } diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index b641162435..8207da0c71 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -13,6 +13,7 @@ use rocket::response::{Flash, Redirect}; use rocket::serde::Serialize; use rocket::form::Form; use rocket::fs::{FileServer, relative}; +use rocket::either::Either; use rocket_dyn_templates::Template; @@ -64,23 +65,23 @@ async fn new(todo_form: Form, conn: DbConn) -> Flash { } #[put("/")] -async fn toggle(id: i32, conn: DbConn) -> Result { +async fn toggle(id: i32, conn: DbConn) -> Either { match Task::toggle_with_id(id, &conn).await { - Ok(_) => Ok(Redirect::to("/")), + Ok(_) => Either::Left(Redirect::to("/")), Err(e) => { error!("DB toggle({id}) error: {e}"); - Err(Template::render("index", Context::err(&conn, "Failed to toggle task.").await)) + Either::Right(Template::render("index", Context::err(&conn, "Failed to toggle task.").await)) } } } #[delete("/")] -async fn delete(id: i32, conn: DbConn) -> Result, Template> { +async fn delete(id: i32, conn: DbConn) -> Either, Template> { match Task::delete_with_id(id, &conn).await { - Ok(_) => Ok(Flash::success(Redirect::to("/"), "Todo was deleted.")), + Ok(_) => Either::Left(Flash::success(Redirect::to("/"), "Todo was deleted.")), Err(e) => { error!("DB deletion({id}) error: {e}"); - Err(Template::render("index", Context::err(&conn, "Failed to delete task.").await)) + Either::Right(Template::render("index", Context::err(&conn, "Failed to delete task.").await)) } } } From 88636e217c91ae3feb087745e1c04615483727e3 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Wed, 20 Nov 2024 02:12:53 -0600 Subject: [PATCH 10/20] Update with working implementation --- core/codegen/src/attribute/catch/mod.rs | 10 +- core/codegen/src/attribute/catch/parse.rs | 12 +- core/codegen/src/attribute/route/mod.rs | 20 ++- core/codegen/src/derive/responder.rs | 2 +- core/codegen/src/derive/typed_error.rs | 28 ++- core/codegen/tests/route.rs | 3 +- core/http/Cargo.toml | 3 +- core/http/src/lib.rs | 2 +- core/http/src/status.rs | 14 ++ core/http/src/uri/fmt/uri_display.rs | 4 +- core/lib/Cargo.toml | 3 +- core/lib/src/catcher/catcher.rs | 64 +++++-- core/lib/src/catcher/handler.rs | 16 +- core/lib/src/catcher/types.rs | 201 +++++++--------------- core/lib/src/data/data.rs | 5 +- core/lib/src/data/from_data.rs | 22 ++- core/lib/src/error.rs | 5 +- core/lib/src/fairing/ad_hoc.rs | 10 +- core/lib/src/fairing/mod.rs | 13 +- core/lib/src/form/parser.rs | 5 +- core/lib/src/http/cookies.rs | 2 +- core/lib/src/lifecycle.rs | 92 +++++++--- core/lib/src/mtls/certificate.rs | 2 +- core/lib/src/mtls/error.rs | 3 + core/lib/src/outcome.rs | 4 +- core/lib/src/request/from_param.rs | 37 ++-- core/lib/src/request/from_request.rs | 27 +-- core/lib/src/response/flash.rs | 7 +- core/lib/src/response/responder.rs | 10 +- core/lib/src/response/status.rs | 4 +- core/lib/src/response/stream/reader.rs | 2 +- core/lib/src/route/handler.rs | 6 +- core/lib/src/router/matcher.rs | 8 +- core/lib/src/router/router.rs | 145 ++++++++-------- core/lib/src/server.rs | 5 +- core/lib/src/state.rs | 4 +- core/lib/tests/panic-handling.rs | 7 +- core/lib/tests/sentinel.rs | 7 +- docs/guide/05-requests.md | 9 +- docs/guide/06-responses.md | 19 +- docs/guide/07-state.md | 4 +- docs/guide/12-pastebin.md | 18 +- docs/guide/14-faq.md | 2 +- examples/manual-routing/src/main.rs | 6 +- examples/pastebin/src/paste_id.rs | 6 +- examples/todo/src/main.rs | 8 +- testbench/src/servers/tracing.rs | 13 +- 47 files changed, 518 insertions(+), 381 deletions(-) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index f6824f8f7f..55835cc599 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -31,8 +31,10 @@ pub fn _catch( let from_error = catch.guards.iter().map(|g| { let name = g.fn_ident.rocketized(); let ty = g.ty.with_replaced_lifetimes(Lifetime::new("'__r", g.ty.span())); - quote_spanned!(g.span() => - let #name: #ty = match <#ty as #FromError<'__r>>::from_error(#__status, #__req, #__error).await { + quote_spanned!(g.span() => + let #name: #ty = match < + #ty as #FromError<'__r> + >::from_error(#__status, #__req, #__error).await { #_Ok(v) => v, #_Err(s) => { // TODO: Typed: log failure @@ -63,7 +65,7 @@ pub fn _catch( }, _ => todo!("Invalid type"), }; - quote_spanned!(g.span() => + quote_spanned!(g.span() => #_catcher::TypeId::of::<#ty>() ) })); @@ -76,7 +78,7 @@ pub fn _catch( let name = a.typed().unwrap().0.rocketized(); quote!(#name) }); - + let catcher_response = quote_spanned!(return_type_span => { let ___responder = #user_catcher_fn_name(#(#args),*) #dot_await; #_response::Responder::respond_to(___responder, #__req).map_err(|e| e.status())? diff --git a/core/codegen/src/attribute/catch/parse.rs b/core/codegen/src/attribute/catch/parse.rs index 8764d14aa8..8309409a3a 100644 --- a/core/codegen/src/attribute/catch/parse.rs +++ b/core/codegen/src/attribute/catch/parse.rs @@ -69,10 +69,18 @@ impl Attribute { if let Some((ident, ty)) = arg.typed() { match meta.error.as_ref() { Some(err) if Name::from(ident) == err.name => { - error = Some(Guard { source: meta.error.clone().unwrap().value, fn_ident: ident.clone(), ty: ty.clone() }); + error = Some(Guard { + source: meta.error.clone().unwrap().value, + fn_ident: ident.clone(), + ty: ty.clone(), + }); } _ => { - guards.push(Guard { source: Dynamic { name: Name::from(ident), index, trailing: false }, fn_ident: ident.clone(), ty: ty.clone() }) + guards.push(Guard { + source: Dynamic { name: Name::from(ident), index, trailing: false }, + fn_ident: ident.clone(), + ty: ty.clone(), + }) } } } else { diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index 68d8ccde86..ba43576332 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -114,7 +114,10 @@ fn query_decls(route: &Route) -> Option { ); } } ); - return #Outcome::Forward((#__data, Box::new(__e) as Box + '__r>)); + return #Outcome::Forward(( + #__data, + Box::new(__e) as Box + '__r> + )); } (#(#ident.unwrap()),*) @@ -141,7 +144,10 @@ fn request_guard_decl(guard: &Guard) -> TokenStream { "request guard forwarding" ); - return #Outcome::Forward((#__data, Box::new(__e) as Box + '__r>)); + return #Outcome::Forward(( + #__data, + Box::new(__e) as Box + '__r> + )); }, #[allow(unreachable_code)] #Outcome::Error(__c) => { @@ -205,7 +211,10 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { #i ); - return #Outcome::Forward((#__data, Box::new(#Status::InternalServerError) as Box + '__r>)) + return #Outcome::Forward(( + #__data, + Box::new(#Status::InternalServerError) as Box + '__r> + )); } } }, @@ -242,7 +251,10 @@ fn data_guard_decl(guard: &Guard) -> TokenStream { "data guard forwarding" ); - return #Outcome::Forward((__d, Box::new(__e) as Box + '__r>)); + return #Outcome::Forward(( + __d, + Box::new(__e) as Box + '__r> + )); } #[allow(unreachable_code)] #Outcome::Error(__e) => { diff --git a/core/codegen/src/derive/responder.rs b/core/codegen/src/derive/responder.rs index 1ff402d711..78a48ca9ed 100644 --- a/core/codegen/src/derive/responder.rs +++ b/core/codegen/src/derive/responder.rs @@ -15,7 +15,7 @@ use crate::http_codegen::{ContentType, Status}; #[derive(Debug, Default, FromMeta)] struct ItemAttr { content_type: Option>, - status: Option>, + status: Option>, } #[derive(Default, FromMeta)] struct FieldAttr { diff --git a/core/codegen/src/derive/typed_error.rs b/core/codegen/src/derive/typed_error.rs index 45960993f3..17efb4bd4e 100644 --- a/core/codegen/src/derive/typed_error.rs +++ b/core/codegen/src/derive/typed_error.rs @@ -9,7 +9,7 @@ use crate::http_codegen::Status; struct ItemAttr { status: Option>, /// Option to generate a respond_to impl with the debug repr of the type - debug: bool, + debug: Option, } #[derive(Default, FromMeta)] @@ -44,9 +44,13 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { } }) .try_fields_map(|_, fields| { - let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())?.unwrap_or(Default::default()); - let status = item.status.map_or(quote!(#_Status::InternalServerError), |m| quote!(#m)); - Ok(if item.debug { + let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())? + .unwrap_or(Default::default()); + let status = item.status.map_or( + quote!(#_Status::InternalServerError), + |m| quote!(#m) + ); + Ok(if item.debug.unwrap_or(false) { quote! { use #_response::Responder; #_response::Debug(self) @@ -63,14 +67,16 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { ) .inner_mapper(MapperBuild::new() .with_output(|_, output| quote! { - fn source(&'r self) -> #_Option<&'r (dyn #TypedError<'r> + 'r)> { - #output + fn source(&'r self, idx: usize) -> #_Option<&'r (dyn #TypedError<'r> + 'r)> { + if idx == 0 { #output } else { #_None } } }) .try_fields_map(|_, fields| { let mut source = None; for field in fields.iter() { - if FieldAttr::one_from_attrs("error", &field.attrs)?.is_some_and(|a| a.source) { + if FieldAttr::one_from_attrs("error", &field.attrs)? + .is_some_and(|a| a.source) + { if source.is_some() { return Err(Diagnostic::spanned( field.span(), @@ -101,8 +107,12 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { fn status(&self) -> #_Status { #output } }) .try_fields_map(|_, fields| { - let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())?.unwrap_or(Default::default()); - let status = item.status.map_or(quote!(#_Status::InternalServerError), |m| quote!(#m)); + let item = ItemAttr::one_from_attrs("error", fields.parent.attrs())? + .unwrap_or(Default::default()); + let status = item.status.map_or( + quote!(#_Status::InternalServerError), + |m| quote!(#m) + ); Ok(quote! { #status }) }) ) diff --git a/core/codegen/tests/route.rs b/core/codegen/tests/route.rs index f9a1bf66e9..9be578d9a1 100644 --- a/core/codegen/tests/route.rs +++ b/core/codegen/tests/route.rs @@ -12,6 +12,7 @@ use rocket::http::ext::Normalize; use rocket::local::blocking::Client; use rocket::data::{self, Data, FromData}; use rocket::http::{Status, RawStr, ContentType, uri::fmt::Path}; +use rocket::response::status::BadRequest; // Use all of the code generation available at once. @@ -24,7 +25,7 @@ struct Simple(String); #[async_trait] impl<'r> FromData<'r> for Simple { - type Error = std::io::Error; + type Error = BadRequest; async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> data::Outcome<'r, Self> { String::from_data(req, data).await.map(Simple) diff --git a/core/http/Cargo.toml b/core/http/Cargo.toml index b2beca4adc..ab9d9c4ed3 100644 --- a/core/http/Cargo.toml +++ b/core/http/Cargo.toml @@ -36,7 +36,8 @@ memchr = "2" stable-pattern = "0.1" cookie = { version = "0.18", features = ["percent-encode"] } state = "0.6" -transient = "0.4.1" +# transient = "0.4.1" +transient = { git = "https://github.com/the10thWiz/transient.git", branch = "rocket-ready" } [dependencies.serde] version = "1.0" diff --git a/core/http/src/lib.rs b/core/http/src/lib.rs index 86950c2428..d0cffbcc02 100644 --- a/core/http/src/lib.rs +++ b/core/http/src/lib.rs @@ -36,6 +36,6 @@ pub mod private { } pub use crate::method::Method; -pub use crate::status::{Status, StatusClass}; +pub use crate::status::{Status, StatusClass, AsStatus}; pub use crate::raw_str::{RawStr, RawStrBuf}; pub use crate::header::*; diff --git a/core/http/src/status.rs b/core/http/src/status.rs index 41d2d90510..78014a97c2 100644 --- a/core/http/src/status.rs +++ b/core/http/src/status.rs @@ -43,6 +43,14 @@ impl StatusClass { class_check_fn!(is_unknown, "`Unknown`.", Unknown); } +/// Trait to convert any type into a status +/// +/// Mostly used to allow `Status` to implement `From` for any type `T`. +pub trait AsStatus { + /// Status associated with this particular object + fn as_status(&self) -> Status; +} + /// Structure representing an HTTP status: an integer code. /// /// A `Status` should rarely be created directly. Instead, an associated @@ -127,6 +135,12 @@ impl Default for Status { } } +impl From for Status { + fn from(val: T) -> Self { + val.as_status() + } +} + macro_rules! ctrs { ($($code:expr, $code_str:expr, $name:ident => $reason:expr),+) => { $( diff --git a/core/http/src/uri/fmt/uri_display.rs b/core/http/src/uri/fmt/uri_display.rs index 621f3337ba..e463ec644b 100644 --- a/core/http/src/uri/fmt/uri_display.rs +++ b/core/http/src/uri/fmt/uri_display.rs @@ -243,13 +243,13 @@ use crate::uri::fmt::{Part, Path, Query, Formatter}; /// const PREFIX: &str = "name:"; /// /// impl<'r> FromParam<'r> for Name<'r> { -/// type Error = &'r str; +/// type Error = (); /// /// /// Validates parameters that start with 'name:', extracting the text /// /// after 'name:' as long as there is at least one character. /// fn from_param(param: &'r str) -> Result { /// if !param.starts_with(PREFIX) || param.len() < (PREFIX.len() + 1) { -/// return Err(param); +/// return Err(()); /// } /// /// let real_name = ¶m[PREFIX.len()..]; diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 156288fa17..4a50a641d8 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -74,7 +74,8 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] } cookie = { version = "0.18", features = ["percent-encode"] } futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" -transient = { version = "0.4.1", features = ["either"] } +# transient = { version = "0.4.1", features = ["either"] } +transient = { features = ["either"], git = "https://github.com/the10thWiz/transient.git", branch = "rocket-ready" } # tracing tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] } diff --git a/core/lib/src/catcher/catcher.rs b/core/lib/src/catcher/catcher.rs index b030285ecb..eb067af751 100644 --- a/core/lib/src/catcher/catcher.rs +++ b/core/lib/src/catcher/catcher.rs @@ -154,22 +154,34 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { - /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// fn handle_404<'r>( + /// status: Status, + /// _: &'r dyn TypedError<'r>, + /// req: &'r Request<'_> + /// ) -> BoxFuture<'r> { + /// let res = (status, format!("404: {}", req.uri())); + /// Box::pin(async move { res.respond_to(req).map_err(|e| e.into()) }) /// } /// - /// fn handle_500<'r>(_: Status, req: &'r Request<'_>) -> BoxFuture<'r> { - /// Box::pin(async move{ "Whoops, we messed up!".respond_to(req) }) + /// fn handle_500<'r>( + /// status: Status, + /// _: &'r dyn TypedError<'r>, + /// req: &'r Request<'_> + /// ) -> BoxFuture<'r> { + /// Box::pin(async move{ "Whoops, we messed up!".respond_to(req).map_err(|e| e.into()) }) /// } /// - /// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_default<'r>( + /// status: Status, + /// _: &'r dyn TypedError<'r>, + /// req: &'r Request<'_> + /// ) -> BoxFuture<'r> { /// let res = (status, format!("{}: {}", status, req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).map_err(|e| e.into()) }) /// } /// /// let not_found_catcher = Catcher::new(404, handle_404); @@ -207,13 +219,17 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_404<'r>( + /// status: Status, + /// _e: &'r dyn TypedError<'r>, + /// req: &'r Request<'_> + /// ) -> BoxFuture<'r> { /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).map_err(|e| e.into()) }) /// } /// /// let catcher = Catcher::new(404, handle_404); @@ -233,14 +249,18 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// # use rocket::uri; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_404<'r>( + /// status: Status, + /// _e: &'r dyn TypedError<'r>, + /// req: &'r Request<'_> + /// ) -> BoxFuture<'r> { /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).map_err(|e| e.into()) }) /// } /// /// let catcher = Catcher::new(404, handle_404); @@ -287,13 +307,17 @@ impl Catcher { /// /// ```rust /// use rocket::request::Request; - /// use rocket::catcher::{Catcher, BoxFuture}; + /// use rocket::catcher::{Catcher, BoxFuture, TypedError}; /// use rocket::response::Responder; /// use rocket::http::Status; /// - /// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> { + /// fn handle_404<'r>( + /// status: Status, + /// _: &'r dyn TypedError<'r>, + /// req: &'r Request<'_> + /// ) -> BoxFuture<'r> { /// let res = (status, format!("404: {}", req.uri())); - /// Box::pin(async move { res.respond_to(req) }) + /// Box::pin(async move { res.respond_to(req).map_err(|e| e.into()) }) /// } /// /// let catcher = Catcher::new(404, handle_404); @@ -321,7 +345,11 @@ impl Catcher { impl Default for Catcher { fn default() -> Self { - fn handler<'r>(status: Status, e: &'r dyn TypedError<'r>, req: &'r Request<'_>) -> BoxFuture<'r> { + fn handler<'r>( + status: Status, + e: &'r dyn TypedError<'r>, + req: &'r Request<'_> + ) -> BoxFuture<'r> { Box::pin(async move { Ok(default_handler(status, e, req)) }) } diff --git a/core/lib/src/catcher/handler.rs b/core/lib/src/catcher/handler.rs index e05302d18a..e285f8411a 100644 --- a/core/lib/src/catcher/handler.rs +++ b/core/lib/src/catcher/handler.rs @@ -32,6 +32,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>; /// use rocket::{Request, Catcher, catcher}; /// use rocket::response::{Response, Responder}; /// use rocket::http::Status; +/// use rocket::catcher::TypedError; /// /// #[derive(Copy, Clone)] /// enum Kind { @@ -45,7 +46,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>; /// /// #[rocket::async_trait] /// impl catcher::Handler for CustomHandler { -/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> catcher::Result<'r> { +/// async fn handle<'r>(&self, status: Status, _e: &'r dyn TypedError<'r>, req: &'r Request<'_>) -> catcher::Result<'r> { /// let inner = match self.0 { /// Kind::Simple => "simple".respond_to(req)?, /// Kind::Intermediate => "intermediate".respond_to(req)?, @@ -98,7 +99,12 @@ pub trait Handler: Cloneable + Send + Sync + 'static { /// Nevertheless, failure is allowed, both for convenience and necessity. If /// an error handler fails, Rocket's default `500` catcher is invoked. If it /// succeeds, the returned `Response` is used to respond to the client. - async fn handle<'r>(&self, status: Status, error: &'r dyn TypedError<'r>, req: &'r Request<'_>) -> Result<'r>; + async fn handle<'r>( + &self, + status: Status, + error: &'r dyn TypedError<'r>, + req: &'r Request<'_> + ) -> Result<'r>; } // We write this manually to avoid double-boxing. @@ -122,7 +128,11 @@ impl Handler for F // Used in tests! Do not use, please. #[doc(hidden)] -pub fn dummy_handler<'r>(_: Status, _: &'r dyn TypedError<'r>, _: &'r Request<'_>) -> BoxFuture<'r> { +pub fn dummy_handler<'r>( + _: Status, + _: &'r dyn TypedError<'r>, + _: &'r Request<'_> +) -> BoxFuture<'r> { Box::pin(async move { Ok(Response::new()) }) } diff --git a/core/lib/src/catcher/types.rs b/core/lib/src/catcher/types.rs index 723a29cdb6..ab96544e64 100644 --- a/core/lib/src/catcher/types.rs +++ b/core/lib/src/catcher/types.rs @@ -2,7 +2,7 @@ use std::fmt; use either::Either; use transient::{Any, CanRecoverFrom, Downcast, Transience}; -use crate::{http::Status, response::status::Custom, Request, Response}; +use crate::{http::{Status, AsStatus}, response::status::Custom, Request, Response}; #[doc(inline)] pub use transient::{Static, Transient, TypeId, Inv, CanTranscendTo}; @@ -46,12 +46,22 @@ pub trait TypedError<'r>: AsAny> + Send + Sync + 'r { /// A descriptive name of this error type. Defaults to the type name. fn name(&self) -> &'static str { std::any::type_name::() } - /// The error that caused this error. Defaults to None. + // /// The error that caused this error. Defaults to None. + // /// + // /// # Warning + // /// A typed catcher will not attempt to follow the source of an error + // /// more than (TODO: exact number) 5 times. + // fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { None } + + // TODO: Typed: need to support case where there are multiple errors + /// The error that caused this error. Defaults to None. Each source + /// should only be returned for one index - this method will be called + /// with indicies starting with 0, and increasing until it returns None. /// /// # Warning /// A typed catcher will not attempt to follow the source of an error /// more than (TODO: exact number) 5 times. - fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { None } + fn source(&'r self, _idx: usize) -> Option<&'r (dyn TypedError<'r> + 'r)> { None } /// Status code fn status(&self) -> Status { Status::InternalServerError } @@ -70,10 +80,6 @@ impl<'r> TypedError<'r> for Status { "" } - fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { - Some(&()) - } - fn status(&self) -> Status { *self } @@ -85,6 +91,11 @@ impl<'r> From for Box + 'r> { } } +impl AsStatus for Box + '_> { + fn as_status(&self) -> Status { + self.status() + } +} // TODO: Typed: update transient to make the possible. // impl<'r, R: TypedError<'r> + Transient> TypedError<'r> for (Status, R) // where R::Transience: CanTranscendTo> @@ -106,6 +117,34 @@ impl<'r> From for Box + 'r> { // } // } +impl<'r, A: TypedError<'r> + Transient, B: TypedError<'r> + Transient> TypedError<'r> for (A, B) + where A::Transience: CanTranscendTo>, + B::Transience: CanTranscendTo>, + // (A, B): Transient, + // <(A, B) as Transient>::Transience: CanTranscendTo>, +{ + fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { + self.0.respond_to(request).or_else(|_| self.1.respond_to(request)) + } + + fn name(&self) -> &'static str { + // TODO: Typed: Should indicate that the + std::any::type_name::<(A, B)>() + } + + fn source(&'r self, idx: usize) -> Option<&'r (dyn TypedError<'r> + 'r)> { + match idx { + 0 => Some(&self.0), + 1 => Some(&self.1), + _ => None, + } + } + + fn status(&self) -> Status { + self.0.status() + } +} + impl<'r, R: TypedError<'r> + Transient> TypedError<'r> for Custom where R::Transience: CanTranscendTo> { @@ -117,8 +156,8 @@ impl<'r, R: TypedError<'r> + Transient> TypedError<'r> for Custom self.1.name() } - fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { - Some(&self.1) + fn source(&'r self, idx: usize) -> Option<&'r (dyn TypedError<'r> + 'r)> { + if idx == 0 { Some(&self.1) } else { None } } fn status(&self) -> Status { @@ -148,10 +187,18 @@ impl<'r> TypedError<'r> for std::num::ParseFloatError { fn status(&self) -> Status { Status::BadRequest } } +impl<'r> TypedError<'r> for std::str::ParseBoolError { + fn status(&self) -> Status { Status::BadRequest } +} + impl<'r> TypedError<'r> for std::string::FromUtf8Error { fn status(&self) -> Status { Status::BadRequest } } +impl<'r> TypedError<'r> for std::net::AddrParseError { + fn status(&self) -> Status { Status::BadRequest } +} + impl<'r> TypedError<'r> for crate::http::uri::error::PathError { fn status(&self) -> Status { Status::BadRequest } } @@ -168,7 +215,8 @@ impl<'r> TypedError<'r> for rmp_serde::encode::Error { } impl<'r> TypedError<'r> for rmp_serde::decode::Error { fn status(&self) -> Status { match self { - rmp_serde::decode::Error::InvalidDataRead(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Status::BadRequest, + rmp_serde::decode::Error::InvalidDataRead(e) + if e.kind() == std::io::ErrorKind::UnexpectedEof => Status::BadRequest, | rmp_serde::decode::Error::TypeMismatch(..) | rmp_serde::decode::Error::OutOfRange | rmp_serde::decode::Error::LengthMismatch(..) => Status::UnprocessableEntity, @@ -182,13 +230,6 @@ impl<'r> TypedError<'r> for uuid_::Error { fn status(&self) -> Status { Status::BadRequest } } -// // TODO: This is a hack to make any static type implement Transient -// impl<'r, T: std::fmt::Debug + Send + Sync + 'static> TypedError<'r> for response::Debug { -// fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { -// format!("{:?}", self.0).respond_to(request).responder_error() -// } -// } - impl<'r, L, R> TypedError<'r> for Either where L: TypedError<'r> + Transient, L::Transience: CanTranscendTo>, @@ -209,10 +250,14 @@ impl<'r, L, R> TypedError<'r> for Either } } - fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { - match self { - Self::Left(v) => Some(v), - Self::Right(v) => Some(v), + fn source(&'r self, idx: usize) -> Option<&'r (dyn TypedError<'r> + 'r)> { + if idx == 0 { + match self { + Self::Left(v) => Some(v), + Self::Right(v) => Some(v), + } + } else { + None } } @@ -230,27 +275,6 @@ impl fmt::Debug for dyn TypedError<'_> { } } -// // TODO: This cannot be used as a bound on an untyped catcher to get any error type. -// // This is mostly an implementation detail (and issue with double boxing) for -// // the responder derive -// // We should just get rid of this. `&dyn TypedError<'_>` impls `FromError` -// #[derive(Transient)] -// pub struct AnyError<'r>(pub Box + 'r>); - -// impl<'r> TypedError<'r> for AnyError<'r> { -// fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { -// Some(self.0.as_ref()) -// } - -// fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { -// self.0.respond_to(request) -// } - -// fn name(&self) -> &'static str { self.0.name() } - -// fn status(&self) -> Status { self.0.status() } -// } - /// Validates that a type implements `TypedError`. Used by the `#[catch]` attribute to ensure /// the `TypeError` is first in the diagnostics. #[doc(hidden)] @@ -267,96 +291,3 @@ pub fn downcast<'r, T>(v: &'r dyn TypedError<'r>) -> Option<&'r T> // crate::trace::error!("Downcasting error from {}", v.name()); v.as_any().downcast_ref() } - -// TODO: Typed: This isn't used at all right now -// /// Upcasts a value to `Box>`, falling back to a default if it doesn't implement -// /// `Error` -// #[doc(hidden)] -// #[macro_export] -// macro_rules! resolve_typed_catcher { -// ($T:expr) => ({ -// #[allow(unused_imports)] -// use $crate::catcher::resolution::{Resolve, DefaultTypeErase, ResolvedTypedError}; - -// let inner = Resolve::new($T).cast(); -// ResolvedTypedError { -// name: inner.as_ref().ok().map(|e| e.name()), -// val: inner, -// } -// }); -// } - -// pub use resolve_typed_catcher; - -// pub mod resolution { -// use std::marker::PhantomData; - -// use transient::{CanTranscendTo, Transient}; - -// use super::*; - -// /// The *magic*. -// /// -// /// `Resolve::item` for `T: Transient` is `::item`. -// /// `Resolve::item` for `T: !Transient` is `DefaultTypeErase::item`. -// /// -// /// This _must_ be used as `Resolve:::item` for resolution to work. This -// /// is a fun, static dispatch hack for "specialization" that works because -// /// Rust prefers inherent methods over blanket trait impl methods. -// pub struct Resolve<'r, T: 'r>(pub T, PhantomData<&'r ()>); - -// impl<'r, T: 'r> Resolve<'r, T> { -// pub fn new(val: T) -> Self { -// Self(val, PhantomData) -// } -// } - -// /// Fallback trait "implementing" `Transient` for all types. This is what -// /// Rust will resolve `Resolve::item` to when `T: !Transient`. -// pub trait DefaultTypeErase<'r>: Sized { -// const SPECIALIZED: bool = false; - -// fn cast(self) -> Result>, Self> { Err(self) } -// } - -// impl<'r, T: 'r> DefaultTypeErase<'r> for Resolve<'r, T> {} - -// /// "Specialized" "implementation" of `Transient` for `T: Transient`. This is -// /// what Rust will resolve `Resolve::item` to when `T: Transient`. -// impl<'r, T: TypedError<'r> + Transient> Resolve<'r, T> -// where T::Transience: CanTranscendTo> -// { -// pub const SPECIALIZED: bool = true; - -// pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } -// } - -// // TODO: These extensions maybe useful, but so far not really -// // // Box can be upcast without double boxing? -// // impl<'r> Resolve<'r, Box>> { -// // pub const SPECIALIZED: bool = true; - -// // pub fn cast(self) -> Result>, Self> { Ok(self.0) } -// // } - -// // Ideally, we should be able to handle this case, but we can't, since we don't own `Either` -// // impl<'r, A, B> Resolve<'r, Either> -// // where A: TypedError<'r> + Transient, -// // A::Transience: CanTranscendTo>, -// // B: TypedError<'r> + Transient, -// // B::Transience: CanTranscendTo>, -// // { -// // pub const SPECIALIZED: bool = true; - -// // pub fn cast(self) -> Result>, Self> { Ok(Box::new(self.0)) } -// // } - -// /// Wrapper type to hold the return type of `resolve_typed_catcher`. -// #[doc(hidden)] -// pub struct ResolvedTypedError<'r, T> { -// /// The return value from `TypedError::name()`, if Some -// pub name: Option<&'static str>, -// /// The upcast error, if it supports it -// pub val: Result + 'r>, Resolve<'r, T>>, -// } -// } diff --git a/core/lib/src/data/data.rs b/core/lib/src/data/data.rs index c54fa00d80..d5ea8d9e1d 100644 --- a/core/lib/src/data/data.rs +++ b/core/lib/src/data/data.rs @@ -114,7 +114,8 @@ impl<'r> Data<'r> { /// use rocket::data::{Data, FromData, Outcome}; /// use rocket::http::Status; /// # struct MyType; - /// # type MyError = String; + /// # #[derive(rocket::TypedError)] + /// # struct MyError; /// /// #[rocket::async_trait] /// impl<'r> FromData<'r> for MyType { @@ -122,7 +123,7 @@ impl<'r> Data<'r> { /// /// async fn from_data(r: &'r Request<'_>, mut data: Data<'r>) -> Outcome<'r, Self> { /// if data.peek(2).await != b"hi" { - /// return Outcome::Forward((data, Status::BadRequest)) + /// return Outcome::Forward((data, MyError)) /// } /// /// /* .. */ diff --git a/core/lib/src/data/from_data.rs b/core/lib/src/data/from_data.rs index 18d52aac51..758c6466dd 100644 --- a/core/lib/src/data/from_data.rs +++ b/core/lib/src/data/from_data.rs @@ -182,7 +182,8 @@ pub type Outcome<'r, T, E = >::Error> /// use rocket::request::Request; /// use rocket::data::{self, Data, FromData}; /// # struct MyType; -/// # type MyError = String; +/// # #[derive(rocket::TypedError)] +/// # struct MyError; /// /// #[rocket::async_trait] /// impl<'r> FromData<'r> for MyType { @@ -232,13 +233,20 @@ pub type Outcome<'r, T, E = >::Error> /// use rocket::data::{self, Data, FromData, ToByteUnit}; /// use rocket::http::{Status, ContentType}; /// use rocket::outcome::Outcome; +/// use rocket::TypedError; /// -/// #[derive(Debug)] +/// #[derive(Debug, TypedError)] /// enum Error { +/// #[error(status = 413)] /// TooLarge, +/// #[error(status = 400)] /// NoColon, +/// #[error(status = 422)] /// InvalidAge, +/// #[error(status = 500)] /// Io(std::io::Error), +/// #[error(status = 415)] +/// UnsupportedMediaType, /// } /// /// #[rocket::async_trait] @@ -251,7 +259,7 @@ pub type Outcome<'r, T, E = >::Error> /// // Ensure the content type is correct before opening the data. /// let person_ct = ContentType::new("application", "x-person"); /// if req.content_type() != Some(&person_ct) { -/// return Outcome::Forward((data, Status::UnsupportedMediaType)); +/// return Outcome::Forward((data, Error::UnsupportedMediaType)); /// } /// /// // Use a configured limit with name 'person' or fallback to default. @@ -260,8 +268,8 @@ pub type Outcome<'r, T, E = >::Error> /// // Read the data into a string. /// let string = match data.open(limit).into_string().await { /// Ok(string) if string.is_complete() => string.into_inner(), -/// Ok(_) => return Outcome::Error((Status::PayloadTooLarge, TooLarge)), -/// Err(e) => return Outcome::Error((Status::InternalServerError, Io(e))), +/// Ok(_) => return Outcome::Error(TooLarge), +/// Err(e) => return Outcome::Error(Io(e)), /// }; /// /// // We store `string` in request-local cache for long-lived borrows. @@ -270,13 +278,13 @@ pub type Outcome<'r, T, E = >::Error> /// // Split the string into two pieces at ':'. /// let (name, age) = match string.find(':') { /// Some(i) => (&string[..i], &string[(i + 1)..]), -/// None => return Outcome::Error((Status::UnprocessableEntity, NoColon)), +/// None => return Outcome::Error(NoColon), /// }; /// /// // Parse the age. /// let age: u16 = match age.parse() { /// Ok(age) => age, -/// Err(_) => return Outcome::Error((Status::UnprocessableEntity, InvalidAge)), +/// Err(_) => return Outcome::Error(InvalidAge), /// }; /// /// Outcome::Success(Person { name, age }) diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 12a0ec4708..9ce70f85f3 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -8,6 +8,7 @@ use figment::Profile; use transient::Static; use crate::http::Status; +use crate::TypedError; use crate::catcher::TypedError; use crate::listener::Endpoint; use crate::{Catcher, Ignite, Orbit, Phase, Rocket, Route}; @@ -93,6 +94,7 @@ pub enum ErrorKind { pub struct Empty; impl Static for Empty {} + impl<'r> TypedError<'r> for Empty { fn status(&self) -> Status { Status::BadRequest @@ -129,7 +131,8 @@ impl<'r> TypedError<'r> for Empty { /// } /// } /// ``` -#[derive(Debug, Clone)] +#[derive(Debug, Clone, TypedError)] +#[error(status = 400)] #[non_exhaustive] pub struct InvalidOption<'a> { /// The value that was provided. diff --git a/core/lib/src/fairing/ad_hoc.rs b/core/lib/src/fairing/ad_hoc.rs index 4a768d2871..90cc20c26c 100644 --- a/core/lib/src/fairing/ad_hoc.rs +++ b/core/lib/src/fairing/ad_hoc.rs @@ -7,6 +7,8 @@ use crate::fairing::{Fairing, Kind, Info, Result}; use crate::route::RouteUri; use crate::trace::Trace; +use super::FilterResult; + /// A ad-hoc fairing that can be created from a function or closure. /// /// This enum can be used to create a fairing from a simple function or closure @@ -176,14 +178,14 @@ impl AdHoc { /// // The no-op request fairing. /// let fairing = AdHoc::on_request_filter("Dummy", |req| { /// Box::pin(async move { - /// // do something with the request and data... - /// # let (_, _) = (req, data); + /// // do something with the request... + /// # let _ = req; /// Ok(()) /// }) /// }); /// ``` pub fn on_request_filter(name: &'static str, f: F) -> AdHoc - where F: for<'a> Fn(&'a Request<'_>) -> BoxFuture<'a, Result<(), Box + 'a>>> + where F: for<'a> Fn(&'a Request<'_>) -> BoxFuture<'a, FilterResult<'a>> { AdHoc { name, kind: AdHocKind::RequestFilter(Box::new(f)) } } @@ -463,7 +465,7 @@ impl Fairing for AdHoc { } } - async fn on_request_filter<'r>(&self, req: &'r Request<'_>) -> Result<(), Box + 'r>> { + async fn on_request_filter<'r>(&self, req: &'r Request<'_>) -> FilterResult<'r> { if let AdHocKind::RequestFilter(ref f) = self.kind { f(req).await } else { diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index e8573ceee6..d24a8356b7 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -65,6 +65,9 @@ pub use self::info_kind::{Info, Kind}; /// A type alias for the return `Result` type of [`Fairing::on_ignite()`]. pub type Result, E = Rocket> = std::result::Result; +/// A type alias for the return `Result` type of [`Fairing::on_request_filter()`] +pub type FilterResult<'r> = std::result::Result<(), Box + 'r>>; + // We might imagine that a request fairing returns an `Outcome`. If it returns // `Success`, we don't do any routing and use that response directly. Same if it // returns `Error`. We only route if it returns `Forward`. I've chosen not to @@ -267,6 +270,7 @@ pub type Result, E = Rocket> = std::result::Result, E = Rocket> = std::result::Result FromRequest<'r> for StartTime { -/// type Error = (); +/// type Error = Status; /// -/// async fn from_request(request: &'r Request<'_>) -> request::Outcome { +/// async fn from_request(request: &'r Request<'_>) -> request::Outcome { /// match *request.local_cache(|| TimerStart(None)) { /// TimerStart(Some(time)) => request::Outcome::Success(StartTime(time)), -/// TimerStart(None) => request::Outcome::Error((Status::InternalServerError, ())), +/// TimerStart(None) => request::Outcome::Error(Status::InternalServerError), /// } /// } /// } @@ -538,8 +542,7 @@ pub trait Fairing: Send + Sync + AsAny + 'static { /// ## Default Implementation /// /// The default implementation of this method does nothing. - async fn on_request_filter<'r>(&self, _req: &'r Request<'_>) - -> Result<(), Box + 'r>> + async fn on_request_filter<'r>(&self, _req: &'r Request<'_>) -> FilterResult<'r> { Ok (()) } /// The response callback. diff --git a/core/lib/src/form/parser.rs b/core/lib/src/form/parser.rs index 980b02e0f1..fe88e9d6df 100644 --- a/core/lib/src/form/parser.rs +++ b/core/lib/src/form/parser.rs @@ -36,7 +36,10 @@ impl<'r, 'i> Parser<'r, 'i> { Some(c) if c.is_form() => Self::from_form(req, data).await, Some(c) if c.is_form_data() => Self::from_multipart(req, data).await, _ => return Outcome::Forward((data, Error { - name: None, value: None, kind: ErrorKind::UnsupportedMediaType, entity: Entity::Form, + name: None, + value: None, + kind: ErrorKind::UnsupportedMediaType, + entity: Entity::Form, }.into())), }; diff --git a/core/lib/src/http/cookies.rs b/core/lib/src/http/cookies.rs index 56f694774a..ee3cbee6ec 100644 --- a/core/lib/src/http/cookies.rs +++ b/core/lib/src/http/cookies.rs @@ -108,7 +108,7 @@ pub use cookie::{Cookie, SameSite, Iter}; /// /// #[rocket::async_trait] /// impl<'r> FromRequest<'r> for User { -/// type Error = std::convert::Infallible; +/// type Error = Status; /// /// async fn from_request(request: &'r Request<'_>) -> request::Outcome { /// request.cookies() diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index f3be7d8d19..0768e6bcf1 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -8,7 +8,7 @@ use crate::http::{Header, Method, Status}; use crate::outcome::Outcome; use crate::trace::Trace; use crate::util::Formatter; -use crate::{catcher, route, Data, Orbit, Request, Response, Rocket}; +use crate::{catcher, route, Catcher, Data, Orbit, Request, Response, Rocket}; // A token returned to force the execution of one method before another. pub(crate) struct RequestToken; @@ -257,7 +257,11 @@ impl Rocket { // // On catcher error, the 500 error catcher is attempted. If _that_ errors, // the (infallible) default 500 error cather is used. - #[tracing::instrument("catching", skip_all, fields(status = error.status().code, uri = %req.uri()))] + #[tracing::instrument( + "catching", + skip_all, + fields(status = error.status().code, uri = %req.uri()) + )] pub(crate) async fn dispatch_error<'r, 's: 'r>( &'s self, mut error: &'r dyn TypedError<'r>, @@ -287,47 +291,79 @@ impl Rocket { } } - /// Invokes the handler with `req` for catcher with status `status`. + /// Find minimum rank typed catcher, following up to 5 * 5 sources. + fn get_min<'s, 'r: 's>( + &'s self, + status: Status, + error: &'r dyn TypedError<'r>, + req: &'r Request<'s>, + depth: usize, + ) -> Option<&'s Catcher> { + const MAX_CALLS_TO_SOURCE: usize = 5; + if depth > MAX_CALLS_TO_SOURCE { + return None; + } + let mut min = self.router.catch(status, Some(error), req); + if let Some(catcher) = self.router.catch_any(status, Some(error), req) { + if min.is_none_or(|m| m.rank > catcher.rank) { + min = Some(catcher); + } + } + for i in 0..MAX_CALLS_TO_SOURCE { + let Some(val) = error.source(i) else { break; }; + if let Some(catcher) = self.get_min(status, val, req, depth + 1) { + if min.is_none_or(|m| m.rank > catcher.rank) { + min = Some(catcher); + } + } + } + min + } + + /// Invokes the handler with `req` for catcher with error `error`. + /// + /// In the order searched: + /// * Matching Status and Type + /// * Matching Type, but not Status + /// * Each of the above, but for the error's `source()`, up to + /// 5 calls deep + /// * Matching Status, but not Type + /// * Default handler + /// * Rocket's default /// - /// In order of preference, invoked handler is: - /// * the user's registered handler for `status` - /// * the user's registered `default` handler - /// * Rocket's default handler for `status` + /// The handler selected to be invoked is the one with the lowest rank. + /// + /// (Rocket's default is implicitly higher ranked than every other catcher) /// /// Return `Ok(result)` if the handler succeeded. Returns `Ok(Some(Status))` /// if the handler ran to completion but failed. Returns `Ok(None)` if the /// handler panicked while executing. - // TODO: Typed: Docs async fn invoke_catcher<'s, 'r: 's>( &'s self, error: &'r dyn TypedError<'r>, req: &'r Request<'s>, ) -> Result, Option> { - const MAX_CALLS_TO_SOURCE: usize = 5; let status = error.status(); - let iter = std::iter::successors(Some(error), |e| e.source()) - .take(MAX_CALLS_TO_SOURCE) - .flat_map(|e| [ - // Catchers with matching status and typeid - self.router.catch(status, Some(e), req), - // Catchers with `default` status and typeid - self.router.catch_any(status, Some(e), req) - ].into_iter().filter_map(|c| c)) - .chain([ - // Catcher with matching status and no typeid - self.router.catch(status, None, req), - // Catcher with `default` status and no typeid - self.router.catch_any(status, None, req) - ].into_iter().filter_map(|c| c)); - // Select lowest rank of (up to) 12 matching catchers. - if let Some(catcher) = iter.min_by_key(|c| c.rank) { + let mut min = self.get_min(status, error, req, 0); + if let Some(catcher) = self.router.catch(status, None, req) { + if min.is_none_or(|m| m.rank > catcher.rank) { + min = Some(catcher); + } + } + if let Some(catcher) = self.router.catch_any(status, None, req) { + if min.is_none_or(|m| m.rank > catcher.rank) { + min = Some(catcher); + } + } + if let Some(catcher) = min { catcher.trace_info(); - catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, error, req)).await + catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, error, req)) + .await .map(|result| result.map_err(Some)) .unwrap_or_else(|| Err(None)) } else { - info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = error.status().code, - "no registered catcher: using Rocket default"); + info!(name: "catcher", name = "rocket::default", "uri.base" = "/", + code = error.status().code, "no registered catcher: using Rocket default"); Ok(catcher::default_handler(status, error, req)) } } diff --git a/core/lib/src/mtls/certificate.rs b/core/lib/src/mtls/certificate.rs index b8645ea7df..e6dfb6bef5 100644 --- a/core/lib/src/mtls/certificate.rs +++ b/core/lib/src/mtls/certificate.rs @@ -55,7 +55,7 @@ use crate::request::{Request, FromRequest, Outcome}; /// if let Some(true) = cert.has_serial(ADMIN_SERIAL) { /// Outcome::Success(CertifiedAdmin(cert)) /// } else { -/// Outcome::Forward(Status::Unauthorized) +/// Outcome::Forward(mtls::Error::SubjectUnauthorized) /// } /// } /// } diff --git a/core/lib/src/mtls/error.rs b/core/lib/src/mtls/error.rs index 9741121e38..06f99493ed 100644 --- a/core/lib/src/mtls/error.rs +++ b/core/lib/src/mtls/error.rs @@ -42,6 +42,8 @@ pub enum Error { Incomplete(Option), /// The certificate contained `.0` bytes of trailing data. Trailing(usize), + /// The subject is not authorized + SubjectUnauthorized, } impl Static for Error {} @@ -59,6 +61,7 @@ impl fmt::Display for Error { Error::Empty => write!(f, "empty certificate chain"), Error::NoSubject => write!(f, "empty subject without subjectAlt"), Error::NonCriticalSubjectAlt => write!(f, "empty subject without critical subjectAlt"), + Error::SubjectUnauthorized => write!(f, "subject not permitted"), } } } diff --git a/core/lib/src/outcome.rs b/core/lib/src/outcome.rs index e72aee71aa..055a23535c 100644 --- a/core/lib/src/outcome.rs +++ b/core/lib/src/outcome.rs @@ -780,7 +780,9 @@ impl<'r, 'o: 'r> IntoOutcome> for response::Result<'r, 'o> { } #[inline] - fn or_forward(self, (data, forward): (Data<'r>, Box + 'r>)) -> route::Outcome<'r> { + fn or_forward(self, + (data, forward): (Data<'r>, Box + 'r>) + ) -> route::Outcome<'r> { match self { Ok(val) => Success(val), Err(_) => Forward((data, forward)) diff --git a/core/lib/src/request/from_param.rs b/core/lib/src/request/from_param.rs index 733e38dd85..8257bf7ddc 100644 --- a/core/lib/src/request/from_param.rs +++ b/core/lib/src/request/from_param.rs @@ -157,27 +157,30 @@ use crate::http::{uri::{Segments, error::PathError, fmt::Path}, Status}; /// /// ```rust /// use rocket::request::FromParam; +/// use rocket::TypedError; /// # #[allow(dead_code)] /// # struct MyParam<'r> { key: &'r str, value: usize } +/// #[derive(TypedError)] +/// struct MyParamError<'a>(&'a str); /// /// impl<'r> FromParam<'r> for MyParam<'r> { -/// type Error = &'r str; +/// type Error = MyParamError<'r>; /// /// fn from_param(param: &'r str) -> Result { /// // We can convert `param` into a `str` since we'll check every /// // character for safety later. /// let (key, val_str) = match param.find(':') { /// Some(i) if i > 0 => (¶m[..i], ¶m[(i + 1)..]), -/// _ => return Err(param) +/// _ => return Err(MyParamError(param)) /// }; /// /// if !key.chars().all(|c| c.is_ascii_alphabetic()) { -/// return Err(param); +/// return Err(MyParamError(param)); /// } /// /// val_str.parse() /// .map(|value| MyParam { key, value }) -/// .map_err(|_| param) +/// .map_err(|_| MyParamError(param)) /// } /// } /// ``` @@ -188,12 +191,15 @@ use crate::http::{uri::{Segments, error::PathError, fmt::Path}, Status}; /// ```rust /// # #[macro_use] extern crate rocket; /// # use rocket::request::FromParam; +/// # use rocket::TypedError; +/// # #[derive(TypedError)] +/// # struct MyParamError<'a>(&'a str); /// # #[allow(dead_code)] /// # struct MyParam<'r> { key: &'r str, value: usize } /// # impl<'r> FromParam<'r> for MyParam<'r> { -/// # type Error = &'r str; +/// # type Error = MyParamError<'r>; /// # fn from_param(param: &'r str) -> Result { -/// # Err(param) +/// # Err(MyParamError(param)) /// # } /// # } /// # @@ -209,7 +215,7 @@ use crate::http::{uri::{Segments, error::PathError, fmt::Path}, Status}; /// ``` pub trait FromParam<'a>: Sized { /// The associated error to be returned if parsing/validation fails. - type Error: std::fmt::Debug; + type Error: TypedError<'a>; /// Parses and validates an instance of `Self` from a path parameter string /// or returns an `Error` if parsing or validation fails. @@ -244,8 +250,8 @@ impl<'a, T: TypedError<'a>> TypedError<'a> for FromParamError<'a, T> self.error.respond_to(request) } - fn source(&'a self) -> Option<&'a (dyn TypedError<'a> + 'a)> { - Some(&self.error) + fn source(&'a self, idx: usize) -> Option<&'a (dyn TypedError<'a> + 'a)> { + if idx == 0 { Some(&self.error) } else { None } } fn status(&self) -> Status { @@ -391,7 +397,7 @@ impl<'a, T: FromParam<'a>> FromParam<'a> for Option { /// the `Utf8Error`. pub trait FromSegments<'r>: Sized { /// The associated error to be returned when parsing fails. - type Error: std::fmt::Debug; + type Error: TypedError<'r>; /// Parses an instance of `Self` from many dynamic path parameter strings or /// returns an `Error` if one cannot be parsed. @@ -427,8 +433,8 @@ impl<'a, T: TypedError<'a>> TypedError<'a> for FromSegmentsError<'a, T> self.error.respond_to(request) } - fn source(&'a self) -> Option<&'a (dyn TypedError<'a> + 'a)> { - Some(&self.error) + fn source(&'a self, idx: usize) -> Option<&'a (dyn TypedError<'a> + 'a)> { + if idx == 0 { Some(&self.error) } else { None } } fn status(&self) -> Status { @@ -522,7 +528,12 @@ impl<'r, T: FromSegments<'r>> FromSegments<'r> for Option { /// returned. If `B::from_param` returns `Ok(b)`, `Either::Right(b)` is /// returned. If both `A::from_param` and `B::from_param` return `Err(a)` and /// `Err(b)`, respectively, then `Err((a, b))` is returned. -impl<'v, A: FromParam<'v>, B: FromParam<'v>> FromParam<'v> for Either { +impl<'v, A: FromParam<'v>, B: FromParam<'v>> FromParam<'v> for Either + where A::Error: Transient, + ::Transience: CanTranscendTo>, + B::Error: Transient, + ::Transience: CanTranscendTo>, +{ type Error = (A::Error, B::Error); #[inline(always)] diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index b97b0365ed..dff3b9febf 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -35,7 +35,9 @@ pub type Outcome = outcome::Outcome; /// ```rust /// use rocket::request::{self, Request, FromRequest}; /// # struct MyType; -/// # type MyError = String; +/// # use rocket::TypedError; +/// # #[derive(TypedError)] +/// # struct MyError; /// /// #[rocket::async_trait] /// impl<'r> FromRequest<'r> for MyType { @@ -209,7 +211,8 @@ pub type Outcome = outcome::Outcome; /// /// struct ApiKey<'r>(&'r str); /// -/// #[derive(Debug)] +/// #[derive(Debug, TypedError)] +/// #[error(status = 400)] /// enum ApiKeyError { /// Missing, /// Invalid, @@ -226,9 +229,9 @@ pub type Outcome = outcome::Outcome; /// } /// /// match req.headers().get_one("x-api-key") { -/// None => Outcome::Error((Status::BadRequest, ApiKeyError::Missing)), +/// None => Outcome::Error(ApiKeyError::Missing), /// Some(key) if is_valid(key) => Outcome::Success(ApiKey(key)), -/// Some(_) => Outcome::Error((Status::BadRequest, ApiKeyError::Invalid)), +/// Some(_) => Outcome::Error(ApiKeyError::Invalid), /// } /// } /// } @@ -264,8 +267,8 @@ pub type Outcome = outcome::Outcome; /// # } /// # #[rocket::async_trait] /// # impl<'r> FromRequest<'r> for Database { -/// # type Error = (); -/// # async fn from_request(request: &'r Request<'_>) -> Outcome { +/// # type Error = Status; +/// # async fn from_request(request: &'r Request<'_>) -> Outcome { /// # Outcome::Success(Database) /// # } /// # } @@ -274,9 +277,9 @@ pub type Outcome = outcome::Outcome; /// # /// #[rocket::async_trait] /// impl<'r> FromRequest<'r> for User { -/// type Error = (); +/// type Error = Status; /// -/// async fn from_request(request: &'r Request<'_>) -> Outcome { +/// async fn from_request(request: &'r Request<'_>) -> Outcome { /// let db = try_outcome!(request.guard::().await); /// request.cookies() /// .get_private("user_id") @@ -288,9 +291,9 @@ pub type Outcome = outcome::Outcome; /// /// #[rocket::async_trait] /// impl<'r> FromRequest<'r> for Admin { -/// type Error = (); +/// type Error = Status; /// -/// async fn from_request(request: &'r Request<'_>) -> Outcome { +/// async fn from_request(request: &'r Request<'_>) -> Outcome { /// // This will unconditionally query the database! /// let user = try_outcome!(request.guard::().await); /// if user.is_admin { @@ -339,7 +342,7 @@ pub type Outcome = outcome::Outcome; /// # /// #[rocket::async_trait] /// impl<'r> FromRequest<'r> for &'r User { -/// type Error = std::convert::Infallible; +/// type Error = Status; /// /// async fn from_request(request: &'r Request<'_>) -> Outcome { /// // This closure will execute at most once per request, regardless of @@ -358,7 +361,7 @@ pub type Outcome = outcome::Outcome; /// /// #[rocket::async_trait] /// impl<'r> FromRequest<'r> for Admin<'r> { -/// type Error = std::convert::Infallible; +/// type Error = Status; /// /// async fn from_request(request: &'r Request<'_>) -> Outcome { /// let user = try_outcome!(request.guard::<&User>().await); diff --git a/core/lib/src/response/flash.rs b/core/lib/src/response/flash.rs index ac5b0febda..5621abf75f 100644 --- a/core/lib/src/response/flash.rs +++ b/core/lib/src/response/flash.rs @@ -52,13 +52,14 @@ const FLASH_COOKIE_DELIM: char = ':'; /// # #[macro_use] extern crate rocket; /// use rocket::response::{Flash, Redirect}; /// use rocket::request::FlashMessage; +/// use rocket::either::Either; /// /// #[post("/login/")] -/// fn login(name: &str) -> Result<&'static str, Flash> { +/// fn login(name: &str) -> Either<&'static str, Flash> { /// if name == "special_user" { -/// Ok("Hello, special user!") +/// Either::Left("Hello, special user!") /// } else { -/// Err(Flash::error(Redirect::to(uri!(index)), "Invalid username.")) +/// Either::Right(Flash::error(Redirect::to(uri!(index)), "Invalid username.")) /// } /// } /// diff --git a/core/lib/src/response/responder.rs b/core/lib/src/response/responder.rs index 6079aa8349..239c3b5c1f 100644 --- a/core/lib/src/response/responder.rs +++ b/core/lib/src/response/responder.rs @@ -174,7 +174,7 @@ use crate::request::Request; /// # struct A; /// // If the response contains no borrowed data. /// impl<'r> Responder<'r, 'static> for A { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { +/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { /// todo!() /// } /// } @@ -182,7 +182,7 @@ use crate::request::Request; /// # struct B<'r>(&'r str); /// // If the response borrows from the request. /// impl<'r> Responder<'r, 'r> for B<'r> { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { +/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'r> { /// todo!() /// } /// } @@ -190,7 +190,7 @@ use crate::request::Request; /// # struct C; /// // If the response is or wraps a borrow that may outlive the request. /// impl<'r, 'o: 'r> Responder<'r, 'o> for &'o C { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { +/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'o> { /// todo!() /// } /// } @@ -198,7 +198,7 @@ use crate::request::Request; /// # struct D(R); /// // If the response wraps an existing responder. /// impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for D { -/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { +/// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'o> { /// todo!() /// } /// } @@ -248,7 +248,7 @@ use crate::request::Request; /// use rocket::http::ContentType; /// /// impl<'r> Responder<'r, 'static> for Person { -/// fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { +/// fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { /// let string = format!("{}:{}", self.name, self.age); /// Response::build_from(string.respond_to(req)?) /// .raw_header("X-Person-Name", self.name) diff --git a/core/lib/src/response/status.rs b/core/lib/src/response/status.rs index 35f4e6544a..e49a0d5451 100644 --- a/core/lib/src/response/status.rs +++ b/core/lib/src/response/status.rs @@ -323,7 +323,9 @@ macro_rules! status_response { fn name(&self) -> &'static str { self.0.name() } - fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { Some(&self.0) } + fn source(&'r self, idx: usize) -> Option<&'r (dyn TypedError<'r> + 'r)> { + if idx == 0 { Some(&self.0) } else { None } + } fn status(&self) -> Status { Status::$T } } diff --git a/core/lib/src/response/stream/reader.rs b/core/lib/src/response/stream/reader.rs index d414996d65..5318370963 100644 --- a/core/lib/src/response/stream/reader.rs +++ b/core/lib/src/response/stream/reader.rs @@ -39,7 +39,7 @@ pin_project! { /// impl<'r, S: Stream> Responder<'r, 'r> for MyStream /// where S: Send + 'r /// { - /// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r> { + /// fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'r> { /// Response::build() /// .header(ContentType::Text) /// .streamed_body(ReaderStream::from(self.0.map(Cursor::new))) diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index 35add00042..9057fd3963 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -4,7 +4,11 @@ use crate::response::{Response, Responder}; /// Type alias for the return type of a [`Route`](crate::Route)'s /// [`Handler::handle()`]. -pub type Outcome<'r> = crate::outcome::Outcome, Box>, (Data<'r>, Box>)>; +pub type Outcome<'r> = crate::outcome::Outcome< + Response<'r>, + Box>, + (Data<'r>, Box>), +>; /// Type alias for the return type of a _raw_ [`Route`](crate::Route)'s /// [`Handler`]. diff --git a/core/lib/src/router/matcher.rs b/core/lib/src/router/matcher.rs index a76d7e64b8..97e86debd0 100644 --- a/core/lib/src/router/matcher.rs +++ b/core/lib/src/router/matcher.rs @@ -120,14 +120,14 @@ impl Catcher { /// // Let's say `request` is `GET /` that 404s. The error matches only `a`: /// let request = client.get("/"); /// # let request = request.inner(); - /// assert!(a.matches(Status::NotFound, &request)); - /// assert!(!b.matches(Status::NotFound, &request)); + /// assert!(a.matches(Status::NotFound, None, &request)); + /// assert!(!b.matches(Status::NotFound, None, &request)); /// /// // Now `request` is a 404 `GET /bar`. The error matches `a` and `b`: /// let request = client.get("/bar"); /// # let request = request.inner(); - /// assert!(a.matches(Status::NotFound, &request)); - /// assert!(b.matches(Status::NotFound, &request)); + /// assert!(a.matches(Status::NotFound, None, &request)); + /// assert!(b.matches(Status::NotFound, None, &request)); /// /// // Note that because `b`'s base' has more complete segments that `a's, /// // Rocket would route the error to `b`, not `a`, even though both match. diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 318e1a3615..fa8be78b97 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -107,7 +107,12 @@ impl Router { // For many catchers, using aho-corasick or similar should be much faster. #[track_caller] - pub fn catch<'r>(&self, status: Status, error: Option<&'r dyn TypedError<'r>>, req: &'r Request<'r>) -> Option<&Catcher> { + pub fn catch<'r>( + &self, + status: Status, + error: Option<&'r dyn TypedError<'r>>, + req: &'r Request<'r> + ) -> Option<&Catcher> { let ty = error.map(|e| e.trait_obj_typeid()); // Note that catchers are presorted by descending base length. self.catcher_map.get(&Some(status.code)) @@ -116,7 +121,12 @@ impl Router { } #[track_caller] - pub fn catch_any<'r>(&self, status: Status, error: Option<&'r dyn TypedError<'r>>, req: &'r Request<'r>) -> Option<&Catcher> { + pub fn catch_any<'r>( + &self, + status: Status, + error: Option<&'r dyn TypedError<'r>>, + req: &'r Request<'r> + ) -> Option<&Catcher> { let ty = error.map(|e| e.trait_obj_typeid()); // Note that catchers are presorted by descending base length. self.catcher_map.get(&None) @@ -617,78 +627,79 @@ mod test { #[test] fn test_catcher_routing() { - // Check that the default `/` catcher catches everything. - assert_catcher_routing! { - catch: [(None, "/")], - reqs: [(404, "/a/b/c"), (500, "/a/b"), (415, "/a/b/d"), (422, "/a/b/c/d?foo")], - with: [(None, "/"), (None, "/"), (None, "/"), (None, "/")] - } - - // Check prefixes when they're exact. - assert_catcher_routing! { - catch: [(None, "/"), (None, "/a"), (None, "/a/b")], - reqs: [ - (404, "/"), (500, "/"), - (404, "/a"), (500, "/a"), - (404, "/a/b"), (500, "/a/b") - ], - with: [ - (None, "/"), (None, "/"), - (None, "/a"), (None, "/a"), - (None, "/a/b"), (None, "/a/b") - ] - } + // TODO: Typed: update tests for new logic - catch got split into two methods. + // // Check that the default `/` catcher catches everything. + // assert_catcher_routing! { + // catch: [(None, "/")], + // reqs: [(404, "/a/b/c"), (500, "/a/b"), (415, "/a/b/d"), (422, "/a/b/c/d?foo")], + // with: [(None, "/"), (None, "/"), (None, "/"), (None, "/")] + // } + + // // Check prefixes when they're exact. + // assert_catcher_routing! { + // catch: [(None, "/"), (None, "/a"), (None, "/a/b")], + // reqs: [ + // (404, "/"), (500, "/"), + // (404, "/a"), (500, "/a"), + // (404, "/a/b"), (500, "/a/b") + // ], + // with: [ + // (None, "/"), (None, "/"), + // (None, "/a"), (None, "/a"), + // (None, "/a/b"), (None, "/a/b") + // ] + // } // Check prefixes when they're not exact. - assert_catcher_routing! { - catch: [(None, "/"), (None, "/a"), (None, "/a/b")], - reqs: [ - (404, "/foo"), (500, "/bar"), (422, "/baz/bar"), (418, "/poodle?yes"), - (404, "/a/foo"), (500, "/a/bar/baz"), (510, "/a/c"), (423, "/a/c/b"), - (404, "/a/b/c"), (500, "/a/b/c/d"), (500, "/a/b?foo"), (400, "/a/b/yes") - ], - with: [ - (None, "/"), (None, "/"), (None, "/"), (None, "/"), - (None, "/a"), (None, "/a"), (None, "/a"), (None, "/a"), - (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), (None, "/a/b") - ] - } + // assert_catcher_routing! { + // catch: [(None, "/"), (None, "/a"), (None, "/a/b")], + // reqs: [ + // (404, "/foo"), (500, "/bar"), (422, "/baz/bar"), (418, "/poodle?yes"), + // (404, "/a/foo"), (500, "/a/bar/baz"), (510, "/a/c"), (423, "/a/c/b"), + // (404, "/a/b/c"), (500, "/a/b/c/d"), (500, "/a/b?foo"), (400, "/a/b/yes") + // ], + // with: [ + // (None, "/"), (None, "/"), (None, "/"), (None, "/"), + // (None, "/a"), (None, "/a"), (None, "/a"), (None, "/a"), + // (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), (None, "/a/b") + // ] + // } // Check that we prefer specific to default. - assert_catcher_routing! { - catch: [(400, "/"), (404, "/"), (None, "/")], - reqs: [ - (400, "/"), (400, "/bar"), (400, "/foo/bar"), - (404, "/"), (404, "/bar"), (404, "/foo/bar"), - (405, "/"), (405, "/bar"), (406, "/foo/bar") - ], - with: [ - (400, "/"), (400, "/"), (400, "/"), - (404, "/"), (404, "/"), (404, "/"), - (None, "/"), (None, "/"), (None, "/") - ] - } + // assert_catcher_routing! { + // catch: [(400, "/"), (404, "/"), (None, "/")], + // reqs: [ + // (400, "/"), (400, "/bar"), (400, "/foo/bar"), + // (404, "/"), (404, "/bar"), (404, "/foo/bar"), + // (405, "/"), (405, "/bar"), (406, "/foo/bar") + // ], + // with: [ + // (400, "/"), (400, "/"), (400, "/"), + // (404, "/"), (404, "/"), (404, "/"), + // (None, "/"), (None, "/"), (None, "/") + // ] + // } // Check that we prefer longer prefixes over specific. - assert_catcher_routing! { - catch: [(None, "/a/b"), (404, "/a"), (422, "/a")], - reqs: [ - (404, "/a/b"), (404, "/a/b/c"), (422, "/a/b/c"), - (404, "/a"), (404, "/a/c"), (404, "/a/cat/bar"), - (422, "/a"), (422, "/a/c"), (422, "/a/cat/bar") - ], - with: [ - (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), - (404, "/a"), (404, "/a"), (404, "/a"), - (422, "/a"), (422, "/a"), (422, "/a") - ] - } + // assert_catcher_routing! { + // catch: [(None, "/a/b"), (404, "/a"), (422, "/a")], + // reqs: [ + // (404, "/a/b"), (404, "/a/b/c"), (422, "/a/b/c"), + // (404, "/a"), (404, "/a/c"), (404, "/a/cat/bar"), + // (422, "/a"), (422, "/a/c"), (422, "/a/cat/bar") + // ], + // with: [ + // (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), + // (404, "/a"), (404, "/a"), (404, "/a"), + // (422, "/a"), (422, "/a"), (422, "/a") + // ] + // } // Just a fun one. - assert_catcher_routing! { - catch: [(None, "/"), (None, "/a/b"), (500, "/a/b/c"), (500, "/a/b")], - reqs: [(404, "/a/b/c"), (500, "/a/b"), (400, "/a/b/d"), (500, "/a/b/c/d?foo")], - with: [(None, "/a/b"), (500, "/a/b"), (None, "/a/b"), (500, "/a/b/c")] - } + // assert_catcher_routing! { + // catch: [(None, "/"), (None, "/a/b"), (500, "/a/b/c"), (500, "/a/b")], + // reqs: [(404, "/a/b/c"), (500, "/a/b"), (400, "/a/b/d"), (500, "/a/b/c/d?foo")], + // with: [(None, "/a/b"), (500, "/a/b"), (None, "/a/b"), (500, "/a/b/c")] + // } } } diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 6f8b5e2f0f..1876787a38 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -45,7 +45,10 @@ impl Rocket { |rocket, request, data| Box::pin(rocket.preprocess(request, data)), |token, rocket, request, error_box, data| Box::pin(async move { if !request.errors.is_empty() { - return rocket.dispatch_error(error_box.write(Box::new(Status::BadRequest)), request).await; + return rocket.dispatch_error( + error_box.write(Box::new(Status::BadRequest)), + request + ).await; } rocket.dispatch(token, request, error_box, data).await diff --git a/core/lib/src/state.rs b/core/lib/src/state.rs index c256c7f0ec..e88fc8551b 100644 --- a/core/lib/src/state.rs +++ b/core/lib/src/state.rs @@ -72,9 +72,9 @@ use crate::http::Status; /// /// #[rocket::async_trait] /// impl<'r> FromRequest<'r> for Item<'r> { -/// type Error = (); +/// type Error = Status; /// -/// async fn from_request(request: &'r Request<'_>) -> request::Outcome { +/// async fn from_request(request: &'r Request<'_>) -> request::Outcome { /// // Using `State` as a request guard. Use `inner()` to get an `'r`. /// let outcome = request.guard::<&State>().await /// .map(|my_config| Item(&my_config.user_val)); diff --git a/core/lib/tests/panic-handling.rs b/core/lib/tests/panic-handling.rs index f5e8c1aea5..bfc2ed1101 100644 --- a/core/lib/tests/panic-handling.rs +++ b/core/lib/tests/panic-handling.rs @@ -4,6 +4,7 @@ use rocket::{Request, Rocket, Route, Catcher, Build, route, catcher}; use rocket::data::Data; use rocket::http::{Method, Status}; use rocket::local::blocking::Client; +use rocket::catcher::TypedError; #[get("/panic")] fn panic_route() -> &'static str { @@ -73,7 +74,11 @@ fn catches_early_route_panic() { #[test] fn catches_early_catcher_panic() { - fn pre_future_catcher<'r>(_: Status, _: &'r Request<'_>) -> catcher::BoxFuture<'r> { + fn pre_future_catcher<'r>( + _: Status, + _: &'r dyn TypedError<'r>, + _: &'r Request<'_> + ) -> catcher::BoxFuture<'r> { panic!("a panicking pre-future catcher") } diff --git a/core/lib/tests/sentinel.rs b/core/lib/tests/sentinel.rs index d88e99b98d..ffeda27e76 100644 --- a/core/lib/tests/sentinel.rs +++ b/core/lib/tests/sentinel.rs @@ -110,7 +110,7 @@ struct Data; #[crate::async_trait] impl<'r> data::FromData<'r> for Data { - type Error = Error; + type Error = std::io::Error; async fn from_data(_: &'r Request<'_>, _: data::Data<'r>) -> data::Outcome<'r, Self> { unimplemented!() } @@ -151,10 +151,11 @@ fn inner_sentinels_detected() { #[derive(Responder)] struct MyThing(T); + #[derive(TypedError)] struct ResponderSentinel; impl<'r, 'o: 'r> response::Responder<'r, 'o> for ResponderSentinel { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'o> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'o> { unimplemented!() } } @@ -171,7 +172,7 @@ fn inner_sentinels_detected() { let err = Client::debug_with(routes![route]).unwrap_err(); assert!(matches!(err.kind(), SentinelAborts(vec) if vec.len() == 1)); - #[derive(Responder)] + #[derive(Responder, TypedError)] struct Inner(T); #[get("/")] diff --git a/docs/guide/05-requests.md b/docs/guide/05-requests.md index 5e8e0bcf95..604bbf4ca5 100644 --- a/docs/guide/05-requests.md +++ b/docs/guide/05-requests.md @@ -557,6 +557,7 @@ dynamically: # #[macro_use] extern crate rocket; # fn main() {} +# use rocket::either::Either; # type Template = (); # type AdminUser = rocket::http::Method; # type User = rocket::http::Method; @@ -566,9 +567,11 @@ dynamically: use rocket::response::Redirect; #[get("/admin", rank = 2)] -fn admin_panel_user(user: Option) -> Result<&'static str, Redirect> { - let user = user.ok_or_else(|| Redirect::to(uri!(login)))?; - Ok("Sorry, you must be an administrator to access this page.") +fn admin_panel_user(user: Option) -> Either<&'static str, Redirect> { + match user { + Some(user) => Either::Left("Sorry, you must be an administrator to access this page."), + None => Either::Right(Redirect::to(uri!(login))), + } } ``` diff --git a/docs/guide/06-responses.md b/docs/guide/06-responses.md index 1b9fcd29c7..5c73847e0b 100644 --- a/docs/guide/06-responses.md +++ b/docs/guide/06-responses.md @@ -275,7 +275,7 @@ use rocket::http::ContentType; # struct String(std::string::String); #[rocket::async_trait] impl<'r> Responder<'r, 'static> for String { - fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, _: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build() .header(ContentType::Plain) # /* @@ -327,30 +327,29 @@ async fn files(file: PathBuf) -> Option { ### `Result` `Result` is another _wrapping_ responder: a `Result` can only be returned -when `T` implements `Responder` and `E` implements `Responder`. +when `T` implements `Responder` and `E` implements `TypedError`. -The wrapped `Responder` in `Ok` or `Err`, whichever it might be, is used to -respond to the client. This means that the responder can be chosen dynamically -at run-time, and two different kinds of responses can be used depending on the -circumstances. Revisiting our file server, for instance, we might wish to -provide more feedback to the user when a file isn't found. We might do this as -follows: +`Result` either uses an `Ok` variant as a responder, or it passes the value in +the `Err` variant to Rocket's catcher mechanism. ```rust # #[macro_use] extern crate rocket; # fn main() {} # use std::path::{Path, PathBuf}; +use std::io; use rocket::fs::NamedFile; use rocket::response::status::NotFound; #[get("/")] -async fn files(file: PathBuf) -> Result> { +async fn files(file: PathBuf) -> Result> { let path = Path::new("static/").join(file); - NamedFile::open(&path).await.map_err(|e| NotFound(e.to_string())) + NamedFile::open(&path).await.map_err(|e| NotFound(e)) } ``` +TODO: Typed: show catching mechanism here + ## Rocket Responders Some of Rocket's best features are implemented through responders. Among these diff --git a/docs/guide/07-state.md b/docs/guide/07-state.md index bbea7fddbe..edb9536d70 100644 --- a/docs/guide/07-state.md +++ b/docs/guide/07-state.md @@ -136,9 +136,9 @@ struct Item<'r>(&'r str); #[rocket::async_trait] impl<'r> FromRequest<'r> for Item<'r> { - type Error = (); + type Error = Status; - async fn from_request(request: &'r Request<'_>) -> request::Outcome { + async fn from_request(request: &'r Request<'_>) -> request::Outcome { // Using `State` as a request guard. Use `inner()` to get an `'r`. let outcome = request.guard::<&State>().await .map(|my_config| Item(&my_config.user_val)); diff --git a/docs/guide/12-pastebin.md b/docs/guide/12-pastebin.md index 7b41470720..36182be5f8 100644 --- a/docs/guide/12-pastebin.md +++ b/docs/guide/12-pastebin.md @@ -326,16 +326,20 @@ Here's the `FromParam` implementation for `PasteId` in `src/paste_id.rs`: use rocket::request::FromParam; # use std::borrow::Cow; # pub struct PasteId<'a>(Cow<'a, str>); +# use rocket::TypedError; + +#[derive(Debug, TypedError)] +pub struct InvalidPasteId; /// Returns an instance of `PasteId` if the path segment is a valid ID. -/// Otherwise returns the invalid ID as the `Err` value. +/// Otherwise returns an `InvalidPasteId` as the error. impl<'a> FromParam<'a> for PasteId<'a> { - type Error = &'a str; + type Error = InvalidPasteId; fn from_param(param: &'a str) -> Result { param.chars().all(|c| c.is_ascii_alphanumeric()) .then(|| PasteId(param.into())) - .ok_or(param) + .ok_or(InvalidPasteId) } } ``` @@ -363,8 +367,10 @@ use rocket::tokio::fs::File; # pub fn new(size: usize) -> PasteId<'static> { todo!() } # pub fn file_path(&self) -> PathBuf { todo!() } # } +# #[derive(rocket::TypedError)] +# pub struct InvalidPasteId; # impl<'a> FromParam<'a> for PasteId<'a> { -# type Error = &'a str; +# type Error = InvalidPasteId; # fn from_param(param: &'a str) -> Result { todo!() } # } @@ -440,8 +446,10 @@ pub struct PasteId<'a>(Cow<'a, str>); # pub fn file_path(&self) -> PathBuf { todo!() } # } # +# #[derive(rocket::TypedError)] +# pub struct InvalidPasteId; # impl<'a> FromParam<'a> for PasteId<'a> { -# type Error = &'a str; +# type Error = InvalidPasteId; # fn from_param(param: &'a str) -> Result { todo!() } # } // We implement the `upload` route in `main.rs`: diff --git a/docs/guide/14-faq.md b/docs/guide/14-faq.md index e4dabca4b8..9547f53606 100644 --- a/docs/guide/14-faq.md +++ b/docs/guide/14-faq.md @@ -420,7 +420,7 @@ use rocket::response::{self, Response, Responder}; use rocket::serde::json::Json; impl<'r> Responder<'r, 'static> for Person { - fn respond_to(self, req: &'r Request<'_>) -> response::Result<'static> { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r, 'static> { Response::build_from(Json(&self).respond_to(req)?) .raw_header("X-Person-Name", self.name) .raw_header("X-Person-Age", self.age.to_string()) diff --git a/examples/manual-routing/src/main.rs b/examples/manual-routing/src/main.rs index 5765fb8ef4..79547d9fc6 100644 --- a/examples/manual-routing/src/main.rs +++ b/examples/manual-routing/src/main.rs @@ -63,7 +63,11 @@ fn get_upload<'r>(req: &'r Request, _: Data<'r>) -> route::BoxFuture<'r> { route::Outcome::from(req, std::fs::File::open(path).ok()).pin() } -fn not_found_handler<'r>(_: Status, _: &'r dyn TypedError<'r>, req: &'r Request) -> catcher::BoxFuture<'r> { +fn not_found_handler<'r>( + _: Status, + _: &'r dyn TypedError<'r>, + req: &'r Request, +) -> catcher::BoxFuture<'r> { let responder = Custom(Status::NotFound, format!("Couldn't find: {}", req.uri())); Box::pin(async move { responder.respond_to(req).map_err(|e| e.status()) }) } diff --git a/examples/pastebin/src/paste_id.rs b/examples/pastebin/src/paste_id.rs index 284d34cd85..841196ed17 100644 --- a/examples/pastebin/src/paste_id.rs +++ b/examples/pastebin/src/paste_id.rs @@ -34,16 +34,16 @@ impl PasteId<'_> { #[derive(Debug, TypedError)] #[error(debug)] -pub struct InvalidId<'a>(pub &'a str); +pub struct InvalidId; /// Returns an instance of `PasteId` if the path segment is a valid ID. /// Otherwise returns the invalid ID as the `Err` value. impl<'a> FromParam<'a> for PasteId<'a> { - type Error = InvalidId<'a>; + type Error = InvalidId; fn from_param(param: &'a str) -> Result { param.chars().all(|c| c.is_ascii_alphanumeric()) .then(|| PasteId(param.into())) - .ok_or(InvalidId(param)) + .ok_or(InvalidId) } } diff --git a/examples/todo/src/main.rs b/examples/todo/src/main.rs index 8207da0c71..9d84c052b5 100644 --- a/examples/todo/src/main.rs +++ b/examples/todo/src/main.rs @@ -70,7 +70,9 @@ async fn toggle(id: i32, conn: DbConn) -> Either { Ok(_) => Either::Left(Redirect::to("/")), Err(e) => { error!("DB toggle({id}) error: {e}"); - Either::Right(Template::render("index", Context::err(&conn, "Failed to toggle task.").await)) + Either::Right( + Template::render("index", Context::err(&conn, "Failed to toggle task.").await) + ) } } } @@ -81,7 +83,9 @@ async fn delete(id: i32, conn: DbConn) -> Either, Template> { Ok(_) => Either::Left(Flash::success(Redirect::to("/"), "Todo was deleted.")), Err(e) => { error!("DB deletion({id}) error: {e}"); - Either::Right(Template::render("index", Context::err(&conn, "Failed to delete task.").await)) + Either::Right( + Template::render("index", Context::err(&conn, "Failed to delete task.").await) + ) } } } diff --git a/testbench/src/servers/tracing.rs b/testbench/src/servers/tracing.rs index b4ada1b3fa..ce5feefbbf 100644 --- a/testbench/src/servers/tracing.rs +++ b/testbench/src/servers/tracing.rs @@ -3,17 +3,16 @@ use std::fmt; -use rocket::http::Status; use rocket::data::{self, FromData}; use rocket::http::uri::{Segments, fmt::Path}; use rocket::request::{self, FromParam, FromRequest, FromSegments}; use crate::prelude::*; -#[derive(Debug)] +#[derive(Debug, TypedError)] struct UseDisplay(&'static str); -#[derive(Debug)] +#[derive(Debug, TypedError)] struct UseDebug; impl fmt::Display for UseDisplay { @@ -36,7 +35,7 @@ impl FromParam<'_> for UseDebug { impl<'r> FromRequest<'r> for UseDisplay { type Error = Self; async fn from_request(_: &'r Request<'_>) -> request::Outcome { - request::Outcome::Error((Status::InternalServerError, Self("req"))) + request::Outcome::Error(Self("req")) } } @@ -44,7 +43,7 @@ impl<'r> FromRequest<'r> for UseDisplay { impl<'r> FromRequest<'r> for UseDebug { type Error = Self; async fn from_request(_: &'r Request<'_>) -> request::Outcome { - request::Outcome::Error((Status::InternalServerError, Self)) + request::Outcome::Error(Self) } } @@ -52,7 +51,7 @@ impl<'r> FromRequest<'r> for UseDebug { impl<'r> FromData<'r> for UseDisplay { type Error = Self; async fn from_data(_: &'r Request<'_>, _: Data<'r>) -> data::Outcome<'r, Self> { - data::Outcome::Error((Status::InternalServerError, Self("data"))) + data::Outcome::Error(Self("data")) } } @@ -60,7 +59,7 @@ impl<'r> FromData<'r> for UseDisplay { impl<'r> FromData<'r> for UseDebug { type Error = Self; async fn from_data(_: &'r Request<'_>, _: Data<'r>) -> data::Outcome<'r, Self> { - data::Outcome::Error((Status::InternalServerError, Self)) + data::Outcome::Error(Self) } } From 32e25a1be01e86414aafdfadc89054611f78f1f0 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Wed, 20 Nov 2024 02:58:51 -0600 Subject: [PATCH 11/20] Add tests, and update collition logic --- core/codegen/src/attribute/catch/mod.rs | 1 + core/codegen/tests/catcher.rs | 101 ++++++++++++++++++++++-- core/codegen/tests/typed_error.rs | 69 ++++++++++++++++ core/lib/src/lifecycle.rs | 24 +++--- core/lib/src/router/collider.rs | 2 +- 5 files changed, 177 insertions(+), 20 deletions(-) create mode 100644 core/codegen/tests/typed_error.rs diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index 55835cc599..a83b138c78 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -52,6 +52,7 @@ pub fn _catch( Some(v) => v, None => { // TODO: Typed: log failure - this should never happen + println!("Failed to downcast error {}", stringify!(#ty)); return #_Err(#Status::InternalServerError); }, }; diff --git a/core/codegen/tests/catcher.rs b/core/codegen/tests/catcher.rs index 59a9b1b345..3887b9a9d9 100644 --- a/core/codegen/tests/catcher.rs +++ b/core/codegen/tests/catcher.rs @@ -5,14 +5,20 @@ #[macro_use] extern crate rocket; +use std::io; +use std::num::ParseIntError; +use std::str::ParseBoolError; + +use rocket::request::FromParamError; use rocket::{Request, Rocket, Build}; use rocket::local::blocking::Client; use rocket::http::Status; +use rocket_http::uri::Origin; #[catch(404)] fn not_found_0() -> &'static str { "404-0" } -#[catch(404)] fn not_found_1(_r: &Request<'_>) -> &'static str { "404-1" } -#[catch(404)] fn not_found_2(_s: Status, _r: &Request<'_>) -> &'static str { "404-2" } -#[catch(default)] fn all(_s: Status, r: &Request<'_>) -> String { r.uri().to_string() } +#[catch(404)] fn not_found_1() -> &'static str { "404-1" } +#[catch(404)] fn not_found_2() -> &'static str { "404-2" } +#[catch(default)] fn all(r: &Request<'_>) -> String { r.uri().to_string() } #[test] fn test_simple_catchers() { @@ -37,10 +43,10 @@ fn test_simple_catchers() { } #[get("/")] fn forward(code: u16) -> Status { Status::new(code) } -#[catch(400)] fn forward_400(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } -#[catch(404)] fn forward_404(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } -#[catch(444)] fn forward_444(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } -#[catch(500)] fn forward_500(status: Status, _r: &Request<'_>) -> String { status.code.to_string() } +#[catch(400)] fn forward_400(status: Status) -> String { status.code.to_string() } +#[catch(404)] fn forward_404(status: Status) -> String { status.code.to_string() } +#[catch(444)] fn forward_444(status: Status) -> String { status.code.to_string() } +#[catch(500)] fn forward_500(status: Status) -> String { status.code.to_string() } #[test] fn test_status_param() { @@ -58,3 +64,84 @@ fn test_status_param() { assert_eq!(response.into_string().unwrap(), code.to_string()); } } + + +#[catch(400)] fn test_status(status: Status) -> String { format!("{}", status.code) } +#[catch(404)] fn test_request(r: &Request<'_>) -> String { format!("{}", r.uri()) } +#[catch(444)] fn test_uri(uri: &Origin<'_>) -> String { format!("{}", uri) } + +#[test] +fn test_basic_params() { + fn rocket() -> Rocket { + rocket::build() + .mount("/", routes![forward]) + .register("/", catchers![ + test_status, + test_request, + test_uri, + ]) + } + + let client = Client::debug(rocket()).unwrap(); + let response = client.get(uri!(forward(400))).dispatch(); + assert_eq!(response.status(), Status::BadRequest); + assert_eq!(response.into_string().unwrap(), "400"); + + let response = client.get(uri!(forward(404))).dispatch(); + assert_eq!(response.status(), Status::NotFound); + assert_eq!(response.into_string().unwrap(), "/404"); + + let response = client.get(uri!(forward(444))).dispatch(); + assert_eq!(response.status(), Status::new(444)); + assert_eq!(response.into_string().unwrap(), "/444"); +} + +#[get("/c/")] fn read_int(code: u16) -> String { format!("{code}") } +#[get("/b/")] fn read_bool(code: bool) -> String { format!("{code}") } +#[get("/b/force")] fn force_bool_error() -> Result<&'static str, ParseBoolError> { + "smt".parse::().map(|_| todo!()) +} + +#[catch(default, error = "")] +fn test_io_error(e: &io::Error) -> String { format!("{e:?}") } +#[catch(default, error = "<_e>")] +fn test_parse_int_error(_e: &ParseIntError) -> String { println!("ParseIntError"); format!("ParseIntError") } +#[catch(default, error = "<_e>")] +fn test_parse_bool_error(_e: &ParseBoolError) -> String { format!("ParseBoolError") } +#[catch(default, error = "")] +fn test_param_parse_bool_error(e: &FromParamError<'_, ParseBoolError>) -> String { format!("ParseBoolError: {}", e.raw) } + + +#[test] +fn test_error_types() { + fn rocket() -> Rocket { + rocket::build() + .mount("/", routes![read_int, read_bool, force_bool_error]) + .register("/", catchers![ + test_io_error, + test_parse_int_error, + test_parse_bool_error, + test_param_parse_bool_error, + ]) + } + + let client = Client::debug(rocket()).unwrap(); + let response = client.get(uri!(read_int(400))).dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.into_string().unwrap(), "400"); + + let response = client.get(uri!("/c/40000000")).dispatch(); + assert_eq!(response.status(), Status::UnprocessableEntity); + assert_eq!(response.into_string().unwrap(), "ParseIntError"); + + let response = client.get(uri!(read_bool(true))).dispatch(); + assert_eq!(response.status(), Status::Ok); + assert_eq!(response.into_string().unwrap(), "true"); + + let response = client.get(uri!("/b/smt")).dispatch(); + assert_eq!(response.status(), Status::UnprocessableEntity); + assert_eq!(response.into_string().unwrap(), "ParseBoolError: smt"); + let response = client.get(uri!("/b/force")).dispatch(); + assert_eq!(response.status(), Status::BadRequest); + assert_eq!(response.into_string().unwrap(), "ParseBoolError"); +} diff --git a/core/codegen/tests/typed_error.rs b/core/codegen/tests/typed_error.rs new file mode 100644 index 0000000000..cbe9ae6cfd --- /dev/null +++ b/core/codegen/tests/typed_error.rs @@ -0,0 +1,69 @@ +#[macro_use] extern crate rocket; +use rocket::catcher::TypedError; +use rocket::http::Status; + +fn boxed_error<'r>(_val: Box + 'r>) {} + +#[derive(TypedError)] +pub enum Foo<'r> { + First(String), + Second(Vec), + Third { + #[error(source)] + responder: std::io::Error, + }, + #[error(status = 400)] + Fourth { + string: &'r str, + }, +} + +#[test] +fn validate_foo() { + let first = Foo::First("".into()); + assert_eq!(first.status(), Status::InternalServerError); + assert!(first.source(0).is_none()); + boxed_error(Box::new(first)); + let second = Foo::Second(vec![]); + assert_eq!(second.status(), Status::InternalServerError); + assert!(second.source(0).is_none()); + boxed_error(Box::new(second)); + let third = Foo::Third { + responder: std::io::Error::new(std::io::ErrorKind::NotFound, ""), + }; + assert_eq!(third.status(), Status::InternalServerError); + assert!(std::ptr::eq( + third.source(0).unwrap(), + if let Foo::Third { responder } = &third { responder } else { panic!() } + )); + boxed_error(Box::new(third)); + let fourth = Foo::Fourth { string: "" }; + assert_eq!(fourth.status(), Status::BadRequest); + assert!(fourth.source(0).is_none()); + boxed_error(Box::new(fourth)); +} + +#[derive(TypedError)] +pub struct InfallibleError { + #[error(source)] + _inner: std::convert::Infallible, +} + +#[derive(TypedError)] +pub struct StaticError { + #[error(source)] + inner: std::string::FromUtf8Error, +} + +#[test] +fn validate_static() { + let val = StaticError { + inner: String::from_utf8(vec![0xFF]).unwrap_err(), + }; + assert_eq!(val.status(), Status::InternalServerError); + assert!(std::ptr::eq( + val.source(0).unwrap(), + &val.inner, + )); + boxed_error(Box::new(val)); +} diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index 0768e6bcf1..24df7c6097 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -298,22 +298,22 @@ impl Rocket { error: &'r dyn TypedError<'r>, req: &'r Request<'s>, depth: usize, - ) -> Option<&'s Catcher> { + ) -> Option<(&'s Catcher, &'r dyn TypedError<'r>)> { const MAX_CALLS_TO_SOURCE: usize = 5; if depth > MAX_CALLS_TO_SOURCE { return None; } - let mut min = self.router.catch(status, Some(error), req); + let mut min = self.router.catch(status, Some(error), req).map(|s| (s, error)); if let Some(catcher) = self.router.catch_any(status, Some(error), req) { - if min.is_none_or(|m| m.rank > catcher.rank) { - min = Some(catcher); + if min.is_none_or(|(m, _)| m.rank > catcher.rank) { + min = Some((catcher, error)); } } for i in 0..MAX_CALLS_TO_SOURCE { let Some(val) = error.source(i) else { break; }; - if let Some(catcher) = self.get_min(status, val, req, depth + 1) { - if min.is_none_or(|m| m.rank > catcher.rank) { - min = Some(catcher); + if let Some((catcher, error)) = self.get_min(status, val, req, depth + 1) { + if min.is_none_or(|(m, _)| m.rank > catcher.rank) { + min = Some((catcher, error)); } } } @@ -346,16 +346,16 @@ impl Rocket { let status = error.status(); let mut min = self.get_min(status, error, req, 0); if let Some(catcher) = self.router.catch(status, None, req) { - if min.is_none_or(|m| m.rank > catcher.rank) { - min = Some(catcher); + if min.is_none_or(|(m, _)| m.rank > catcher.rank) { + min = Some((catcher, error)); } } if let Some(catcher) = self.router.catch_any(status, None, req) { - if min.is_none_or(|m| m.rank > catcher.rank) { - min = Some(catcher); + if min.is_none_or(|(m, _)| m.rank > catcher.rank) { + min = Some((catcher, error)); } } - if let Some(catcher) = min { + if let Some((catcher, error)) = min { catcher.trace_info(); catch_handle(catcher.name.as_deref(), || catcher.handler.handle(status, error, req)) .await diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index d55ded8ec6..c62d3ef6ad 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -141,7 +141,7 @@ impl Catcher { /// assert!(!a.collides_with(&b)); /// ``` pub fn collides_with(&self, other: &Self) -> bool { - self.code == other.code && self.base().segments().eq(other.base().segments()) + self.code == other.code && self.base().segments().eq(other.base().segments()) && self.type_id == other.type_id } } From ddc84171bae7dd6673cf87b127e0dbcc65e4fed4 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Wed, 20 Nov 2024 03:45:36 -0600 Subject: [PATCH 12/20] Add better tests --- core/codegen/src/attribute/catch/mod.rs | 30 ++-- core/lib/src/router/router.rs | 175 ++++++++++-------------- 2 files changed, 90 insertions(+), 115 deletions(-) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index a83b138c78..bd1f081cf1 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -1,6 +1,6 @@ mod parse; -use devise::ext::TypeExt; +use devise::ext::{SpanDiagnosticExt, TypeExt}; use devise::{Spanned, Result}; use proc_macro2::{TokenStream, Span}; use syn::{Lifetime, TypeReference}; @@ -36,9 +36,16 @@ pub fn _catch( #ty as #FromError<'__r> >::from_error(#__status, #__req, #__error).await { #_Ok(v) => v, - #_Err(s) => { - // TODO: Typed: log failure - return #_Err(s); + #_Err(__e) => { + ::rocket::trace::info!( + name: "error", + target: concat!("rocket::codegen::route::", module_path!()), + parameter = stringify!(#name), + type_name = stringify!(#ty), + status = __e.code, + "error guard error" + ); + return #_Err(__e); }, }; ) @@ -51,8 +58,11 @@ pub fn _catch( let #name: #ty = match #_catcher::downcast(#__error) { Some(v) => v, None => { - // TODO: Typed: log failure - this should never happen - println!("Failed to downcast error {}", stringify!(#ty)); + ::rocket::trace::error!( + downcast_to = stringify!(#ty), + error_name = #__error.name(), + "Failed to downcast error. This should never happen, please open an issue with details." + ); return #_Err(#Status::InternalServerError); }, }; @@ -64,12 +74,12 @@ pub fn _catch( syn::Type::Reference(TypeReference { mutability: None, elem, .. }) => { elem.as_ref().with_stripped_lifetimes() }, - _ => todo!("Invalid type"), + _ => return Err(g.ty.span().error("invalid type, must be a reference")), }; - quote_spanned!(g.span() => + Ok(quote_spanned!(g.span() => #_catcher::TypeId::of::<#ty>() - ) - })); + )) + }).transpose()?); // We append `.await` to the function call if this is `async`. let dot_await = catch.function.sig.asyncness diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index fa8be78b97..b810c21bd1 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -151,8 +151,11 @@ impl DerefMut for Router { #[cfg(test)] mod test { + use transient::TypeId; + use super::*; + use crate::catcher; use crate::route::dummy_handler; use crate::local::blocking::Client; use crate::http::{Method::*, uri::Origin}; @@ -585,121 +588,83 @@ mod test { ); } - fn router_with_catchers(catchers: &[(Option, &str)]) -> Result> { + fn make_router_catches(catchers: I) -> Result, Collisions> + where I: IntoIterator, &'static str, Option)> + { let mut router = Router::new(); - for (code, base) in catchers { - let catcher = Catcher::new(*code, crate::catcher::dummy_handler); - router.catchers.push(catcher.map_base(|_| base.to_string()).unwrap()); + for (status, base, ty) in catchers { + let mut catcher = Catcher::new(status.map(|s| s.code), catcher::dummy_handler).rebase(Origin::parse(base).unwrap()); + catcher.type_id = ty; + router.catchers.push(catcher); } router.finalize() } + #[test] + fn test_catcher_collisions() { + #[derive(transient::Transient)] + struct A; + assert!(make_router_catches([ + (Some(Status::new(400)), "/", None), + (Some(Status::new(400)), "/", None) + ]).is_err()); + assert!(make_router_catches([ + (Some(Status::new(400)), "/", Some(TypeId::of::<()>())), + (Some(Status::new(400)), "/", None) + ]).is_ok()); + assert!(make_router_catches([ + (Some(Status::new(400)), "/", Some(TypeId::of::<()>())), + (Some(Status::new(400)), "/", Some(TypeId::of::<()>())), + ]).is_err()); + assert!(make_router_catches([ + (Some(Status::new(400)), "/", Some(TypeId::of::<()>())), + (Some(Status::new(400)), "/", Some(TypeId::of::())), + ]).is_ok()); + } + #[track_caller] - fn catcher<'a>(r: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { + fn catches<'a>( + router: &'a Router, + status: Status, + uri: &'a str, + ty: for<'r> fn(&'r ()) -> Option<&'r dyn TypedError<'r>> + ) -> Option<&'a Catcher> { let client = Client::debug_with(vec![]).expect("client"); - let request = client.get(Origin::parse(uri).unwrap()); - r.catch(status, None, &request) - } - - macro_rules! assert_catcher_routing { - ( - catch: [$(($code:expr, $uri:expr)),+], - reqs: [$($r:expr),+], - with: [$(($ecode:expr, $euri:expr)),+] - ) => ({ - let catchers = vec![$(($code.into(), $uri)),+]; - let requests = vec![$($r),+]; - let expected = vec![$(($ecode.into(), $euri)),+]; - - let router = router_with_catchers(&catchers).expect("valid router"); - for (req, expected) in requests.iter().zip(expected.iter()) { - let req_status = Status::from_code(req.0).expect("valid status"); - let catcher = catcher(&router, req_status, req.1).expect("some catcher"); - assert_eq!(catcher.code, expected.0, - "\nmatched {:?}, expected {:?} for req {:?}", catcher, expected, req); - - assert_eq!(catcher.base.path(), expected.1, - "\nmatched {:?}, expected {:?} for req {:?}", catcher, expected, req); - } - }) + let request = client.req(Method::Get, Origin::parse(uri).unwrap()); + router.catch(status, ty(&()), &request) + } + + #[track_caller] + fn catches_any<'a>( + router: &'a Router, + status: Status, + uri: &'a str, + ty: for<'r> fn(&'r ()) -> Option<&'r dyn TypedError<'r>> + ) -> Option<&'a Catcher> { + let client = Client::debug_with(vec![]).expect("client"); + let request = client.req(Method::Get, Origin::parse(uri).unwrap()); + router.catch_any(status, ty(&()), &request) + } + + #[test] + fn test_catch_vs_catch_any() { + let router = make_router_catches([(None, "/", None)]).unwrap(); + + assert!(catches(&router, Status::BadRequest, "/", |_| None).is_none()); + assert!(catches_any(&router, Status::BadRequest, "/", |_| None).is_some()); } #[test] - fn test_catcher_routing() { - // TODO: Typed: update tests for new logic - catch got split into two methods. - // // Check that the default `/` catcher catches everything. - // assert_catcher_routing! { - // catch: [(None, "/")], - // reqs: [(404, "/a/b/c"), (500, "/a/b"), (415, "/a/b/d"), (422, "/a/b/c/d?foo")], - // with: [(None, "/"), (None, "/"), (None, "/"), (None, "/")] - // } - - // // Check prefixes when they're exact. - // assert_catcher_routing! { - // catch: [(None, "/"), (None, "/a"), (None, "/a/b")], - // reqs: [ - // (404, "/"), (500, "/"), - // (404, "/a"), (500, "/a"), - // (404, "/a/b"), (500, "/a/b") - // ], - // with: [ - // (None, "/"), (None, "/"), - // (None, "/a"), (None, "/a"), - // (None, "/a/b"), (None, "/a/b") - // ] - // } - - // Check prefixes when they're not exact. - // assert_catcher_routing! { - // catch: [(None, "/"), (None, "/a"), (None, "/a/b")], - // reqs: [ - // (404, "/foo"), (500, "/bar"), (422, "/baz/bar"), (418, "/poodle?yes"), - // (404, "/a/foo"), (500, "/a/bar/baz"), (510, "/a/c"), (423, "/a/c/b"), - // (404, "/a/b/c"), (500, "/a/b/c/d"), (500, "/a/b?foo"), (400, "/a/b/yes") - // ], - // with: [ - // (None, "/"), (None, "/"), (None, "/"), (None, "/"), - // (None, "/a"), (None, "/a"), (None, "/a"), (None, "/a"), - // (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), (None, "/a/b") - // ] - // } - - // Check that we prefer specific to default. - // assert_catcher_routing! { - // catch: [(400, "/"), (404, "/"), (None, "/")], - // reqs: [ - // (400, "/"), (400, "/bar"), (400, "/foo/bar"), - // (404, "/"), (404, "/bar"), (404, "/foo/bar"), - // (405, "/"), (405, "/bar"), (406, "/foo/bar") - // ], - // with: [ - // (400, "/"), (400, "/"), (400, "/"), - // (404, "/"), (404, "/"), (404, "/"), - // (None, "/"), (None, "/"), (None, "/") - // ] - // } - - // Check that we prefer longer prefixes over specific. - // assert_catcher_routing! { - // catch: [(None, "/a/b"), (404, "/a"), (422, "/a")], - // reqs: [ - // (404, "/a/b"), (404, "/a/b/c"), (422, "/a/b/c"), - // (404, "/a"), (404, "/a/c"), (404, "/a/cat/bar"), - // (422, "/a"), (422, "/a/c"), (422, "/a/cat/bar") - // ], - // with: [ - // (None, "/a/b"), (None, "/a/b"), (None, "/a/b"), - // (404, "/a"), (404, "/a"), (404, "/a"), - // (422, "/a"), (422, "/a"), (422, "/a") - // ] - // } - - // Just a fun one. - // assert_catcher_routing! { - // catch: [(None, "/"), (None, "/a/b"), (500, "/a/b/c"), (500, "/a/b")], - // reqs: [(404, "/a/b/c"), (500, "/a/b"), (400, "/a/b/d"), (500, "/a/b/c/d?foo")], - // with: [(None, "/a/b"), (500, "/a/b"), (None, "/a/b"), (500, "/a/b/c")] - // } + fn test_catch_vs_catch_any_ty() { + let router = make_router_catches([ + (None, "/", None), + (None, "/", Some(TypeId::of::<()>())) + ]).unwrap(); + + assert!(catches_any(&router, Status::BadRequest, "/", |_| None).unwrap().type_id.is_none()); + assert!( + catches_any(&router, Status::BadRequest, "/", |_| Some(&())).unwrap().type_id.is_some() + ); } } From 2a63797857a10597ac426e4e739ded334f9c9e5d Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Thu, 21 Nov 2024 22:30:51 -0600 Subject: [PATCH 13/20] Update derive docs --- core/codegen/src/lib.rs | 80 ++++++++++++++++++++++++++---- core/lib/src/response/debug.rs | 1 + core/lib/src/response/responder.rs | 1 - 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index 731f0db79e..9e2cbcbce4 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -320,19 +320,21 @@ route_attribute!(options => Method::Options); /// The grammar for the `#[catch]` attributes is defined as: /// /// ```text -/// catch := STATUS | 'default' +/// catch := (STATUS | 'default') (',' parameter)* /// /// STATUS := valid HTTP status code (integer in [200, 599]) +/// +/// parameter := 'error' '=' '"' SINGLE_PARAMETER '"' +/// SINGLE_PARAM := '<' IDENT '>' /// ``` /// /// # Typing Requirements /// -/// The decorated function may take zero, one, or two arguments. It's type -/// signature must be one of the following, where `R:`[`Responder`]: -/// -/// * `fn() -> R` -/// * `fn(`[`&Request`]`) -> R` -/// * `fn(`[`Status`]`, `[`&Request`]`) -> R` +/// The `error` `SINGLE_PARAMETER`, if present, must be a reference to a type +/// that implements [`TypedError`]. All other parameter types must implement +/// [`FromError`]. There is a blanket impl for any type that implements +/// [`FromRequest`], so most types will work as expected. `Status`, `&Request`, +/// and `&dyn TypedError` also implement [`FromError`]. /// /// # Semantics /// @@ -341,9 +343,7 @@ route_attribute!(options => Method::Options); /// 1. An error [`Handler`]. /// /// The generated handler calls the decorated function, passing in the -/// [`Status`] and [`&Request`] values if requested. The returned value is -/// used to generate a [`Response`] via the type's [`Responder`] -/// implementation. +/// error type, and every other parameter requested. /// /// 2. A static structure used by [`catchers!`] to generate a [`Catcher`]. /// @@ -354,6 +354,8 @@ route_attribute!(options => Method::Options); /// [`&Request`]: ../rocket/struct.Request.html /// [`Status`]: ../rocket/http/struct.Status.html /// [`Handler`]: ../rocket/catcher/trait.Handler.html +/// [`TypedError`]: ../rocket/catcher/trait.TypedError.html +/// [`FromError`]: ../rocket/catcher/trait.FromError.html /// [`catchers!`]: macro.catchers.html /// [`Catcher`]: ../rocket/struct.Catcher.html /// [`Response`]: ../rocket/struct.Response.html @@ -1017,7 +1019,63 @@ pub fn derive_responder(input: TokenStream) -> TokenStream { } /// Derive for the [`TypedError`] trait. -// TODO: Typed: Docs +/// +/// The [`TypedError`] derive can be applied to structs and enums, so +/// they can be used as an error type in Rocket. +/// +/// ```rust +/// # #[macro_use] extern crate rocket; +/// #[derive(TypedError)] +/// struct InvalidCookieValue<'r> { +/// name: &'r str, +/// value: &'r str, +/// } +/// +/// #[derive(TypedError)] +/// enum HeaderError { +/// InvalidValue, +/// Missing, +/// } +/// ``` +/// +/// # Semantics +/// +/// The derive generates an implementation of [`TypedError`] for the decorated +/// struct. The exact implementation can be modiefied using the `error` attribtute. +/// When applied to the outer struct or enum variants, `status` sets the status +/// associated with this type. `source` can only be applied to individual fields, +/// and indicates that the field should be returned from the `source()` method. +/// This means that field can also be used as the `error` type in a catcher. +/// Finally, `debug` generates a default `respond_to` impl, using the +/// [`Debug`](std::fmt::Debug) implementation of the type. +/// +/// ```text +/// response := parameter (',' parameter)? +/// +/// parameter := 'status' '=' STATUS +/// | 'source' +/// | 'debug' +/// +/// STATUS := unsigned integer >= 100 and < 600 +/// ``` +/// +/// # Generics +/// +/// The `TypedError` derive allows at most one lifetime, but as many generic parameters +/// as you want. Generic parameters will be restricted to `'static`. For example: +/// +/// ```rust +/// # #[macro_use] extern crate rocket; +/// // The bound `E: 'static` will be added. +/// #[derive(TypedError)] +/// struct InvalidCookieValue<'r, E> { +/// name: &'r str, +/// value: &'r str, +/// inner_error: E, +/// } +/// ``` +/// +/// [`TypedError`]: ../rocket/catcher/trait.TypedError.html #[proc_macro_derive(TypedError, attributes(error))] pub fn derive_typed_error(input: TokenStream) -> TokenStream { emit!(derive::typed_error::derive_typed_error(input)) diff --git a/core/lib/src/response/debug.rs b/core/lib/src/response/debug.rs index c5c7f2f51b..82fda0cc6e 100644 --- a/core/lib/src/response/debug.rs +++ b/core/lib/src/response/debug.rs @@ -88,6 +88,7 @@ impl<'r, E: std::fmt::Debug> Responder<'r, 'static> for Debug { } // TODO: Typed: This is a stop-gap measure to allow any 'static type to be a `TypedError` +// I think is going to be quite useful going forward, since most error types are 'static impl<'r, E: std::fmt::Debug + Send + Sync + 'static> TypedError<'r> for Debug { fn respond_to(&self, _: &'r Request<'_>) -> Result, Status> { let type_name = std::any::type_name::(); diff --git a/core/lib/src/response/responder.rs b/core/lib/src/response/responder.rs index 239c3b5c1f..dfda9200d8 100644 --- a/core/lib/src/response/responder.rs +++ b/core/lib/src/response/responder.rs @@ -549,7 +549,6 @@ impl<'r> Responder<'r, 'static> for Status { "invalid status used as responder\n\ status must be one of 100, 200..=205, 400..=599"); - // TODO: Typed: Invalid status Err(Box::new(Status::InternalServerError)) } } From aa845adc16767a49d65ad3bece71387511b98246 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Thu, 21 Nov 2024 23:00:39 -0600 Subject: [PATCH 14/20] Updating docs --- contrib/sync_db_pools/lib/src/lib.rs | 208 ++++++--------------------- core/codegen/src/lib.rs | 5 + core/http/src/uri/uri.rs | 4 +- core/lib/src/fairing/mod.rs | 5 +- core/lib/src/form/context.rs | 4 +- 5 files changed, 57 insertions(+), 169 deletions(-) diff --git a/contrib/sync_db_pools/lib/src/lib.rs b/contrib/sync_db_pools/lib/src/lib.rs index 7c35e3b0ba..70c1de0139 100644 --- a/contrib/sync_db_pools/lib/src/lib.rs +++ b/contrib/sync_db_pools/lib/src/lib.rs @@ -402,124 +402,33 @@ pub mod example { /// #[database("example")] /// pub struct ExampleDb(diesel::SqliteConnection); /// ``` - pub struct ExampleDb(crate::Connection); - + pub struct ExampleDb( + #[allow(dead_code)] + crate::Connection, + ); + #[allow(dead_code)] impl ExampleDb { - /// Returns a fairing that initializes the database connection pool - /// associated with `Self`. - /// - /// The fairing _must_ be attached before `Self` can be used as a - /// request guard. - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// # #[macro_use] extern crate rocket_sync_db_pools; - /// # - /// # #[cfg(feature = "diesel_sqlite_pool")] { - /// use rocket_sync_db_pools::diesel; - /// - /// #[database("my_db")] - /// struct MyConn(diesel::SqliteConnection); - /// - /// #[launch] - /// fn rocket() -> _ { - /// rocket::build().attach(MyConn::fairing()) - /// } - /// # } - /// ``` + /// Returns a fairing that initializes the database connection pool. pub fn fairing() -> impl crate::rocket::fairing::Fairing { - >::fairing( - "'example' Database Pool", - "example", - ) + >::fairing("'example' Database Pool", "example") } - - /// Returns an opaque type that represents the connection pool backing - /// connections of type `Self` _as long as_ the fairing returned by - /// [`Self::fairing()`] is attached and has run on `__rocket`. - /// - /// The returned pool is `Clone`. Values of type `Self` can be retrieved - /// from the pool by calling `pool.get().await` which has the same - /// signature and semantics as [`Self::get_one()`]. - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// # #[macro_use] extern crate rocket_sync_db_pools; - /// # - /// # #[cfg(feature = "diesel_sqlite_pool")] { - /// use rocket::tokio::{task, time}; - /// use rocket::fairing::AdHoc; - /// use rocket_sync_db_pools::diesel; - /// - /// #[database("my_db")] - /// struct MyConn(diesel::SqliteConnection); - /// - /// #[launch] - /// fn rocket() -> _ { - /// rocket::build() - /// .attach(MyConn::fairing()) - /// .attach(AdHoc::try_on_ignite("Background DB", |rocket| async { - /// let pool = match MyConn::pool(&rocket) { - /// Some(pool) => pool.clone(), - /// None => return Err(rocket) - /// }; - /// - /// // Start a background task that runs some database - /// // operation every 10 seconds. If a connection isn't - /// // available, retries 10 + timeout seconds later. - /// tokio::task::spawn(async move { - /// loop { - /// time::sleep(time::Duration::from_secs(10)).await; - /// if let Some(conn) = pool.get().await { - /// conn.run(|c| { /* perform db ops */ }).await; - /// } - /// } - /// }); - /// - /// Ok(rocket) - /// })) - /// } - /// # } - /// ``` + /// Returns an opaque type that represents the connection pool + /// backing connections of type `Self`. pub fn pool( __rocket: &crate::rocket::Rocket

, - ) -> Option<&crate::ConnectionPool> - { - >::pool( - &__rocket, - ) + ) -> Option< + &crate::ConnectionPool, + > { + >::pool(&__rocket) } - - /// Runs the provided function `__f` in an async-safe blocking thread. - /// The function is supplied with a mutable reference to the raw - /// connection (a value of type `&mut Self.0`). `.await`ing the return - /// value of this function yields the value returned by `__f`. - /// - /// # Example - /// - /// ```rust - /// # #[macro_use] extern crate rocket; - /// # #[macro_use] extern crate rocket_sync_db_pools; - /// # - /// # #[cfg(feature = "diesel_sqlite_pool")] { - /// use rocket_sync_db_pools::diesel; - /// - /// #[database("my_db")] - /// struct MyConn(diesel::SqliteConnection); - /// - /// #[get("/")] - /// async fn f(conn: MyConn) { - /// // The type annotation is illustrative and isn't required. - /// let result = conn.run(|c: &mut diesel::SqliteConnection| { - /// // Use `c`. - /// }).await; - /// } - /// # } - /// ``` + /// Runs the provided function `__f` in an async-safe blocking + /// thread. pub async fn run(&self, __f: F) -> R where F: FnOnce(&mut diesel::SqliteConnection) -> R + Send + 'static, @@ -527,69 +436,42 @@ pub mod example { { self.0.run(__f).await } - /// Retrieves a connection of type `Self` from the `rocket` instance. - /// Returns `Some` as long as `Self::fairing()` has been attached and - /// there is a connection available within at most `timeout` seconds. pub async fn get_one( __rocket: &crate::rocket::Rocket

, ) -> Option { - >::get_one( - &__rocket, - ) - .await - .map(Self) + >::get_one(&__rocket) + .await + .map(Self) } } - - /// Retrieves a connection from the database pool or fails with a - /// `Status::ServiceUnavailable` if doing so times out. + #[crate::rocket::async_trait] impl<'r> crate::rocket::request::FromRequest<'r> for ExampleDb { - type Error = (); - #[allow( - clippy::let_unit_value, - clippy::no_effect_underscore_binding, - clippy::shadow_same, - clippy::type_complexity, - clippy::type_repetition_in_bounds, - clippy::used_underscore_binding - )] - fn from_request<'life0, 'async_trait>( - __r: &'r crate::rocket::request::Request<'life0>, - ) -> ::core::pin::Pin< - Box< - dyn ::core::future::Future< - Output = crate::rocket::request::Outcome, - > + ::core::marker::Send - + 'async_trait, - >, - > - where - 'r: 'async_trait, - 'life0: 'async_trait, - Self: 'async_trait, - { - Box::pin(async move { - if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::< - crate::rocket::request::Outcome, - > { - return __ret; - } - let __r = __r; - let __ret: crate::rocket::request::Outcome = { - < crate :: Connection < Self , diesel :: SqliteConnection > > - :: from_request (__r) . await . map (Self) - }; - #[allow(unreachable_code)] - __ret - }) + type Error = crate::ConnectionMissing; + async fn from_request( + __r: &'r crate::rocket::request::Request<'_>, + ) -> crate::rocket::request::Outcome { + >::from_request(__r) + .await + .map(Self) } } impl crate::rocket::Sentinel for ExampleDb { fn abort( - __r: &crate::rocket::Rocket, + __r: &crate::rocket::Rocket< + crate::rocket::Ignite, + >, ) -> bool { - >::abort(__r) + >::abort(__r) } } } diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index 9e2cbcbce4..c4a5ca8ac9 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -313,6 +313,11 @@ route_attribute!(options => Method::Options); /// fn default(status: Status, req: &Request) -> String { /// format!("{} ({})", status, req.uri()) /// } +/// +/// #[catch(500, error = "")] +/// fn std_io_error(e: &std::io::Error) -> String { +/// format!("Std error: {:?}", e) +/// } /// ``` /// /// # Grammar diff --git a/core/http/src/uri/uri.rs b/core/http/src/uri/uri.rs index 14360e7bc1..03ba0297fe 100644 --- a/core/http/src/uri/uri.rs +++ b/core/http/src/uri/uri.rs @@ -79,7 +79,7 @@ impl<'a> Uri<'a> { /// // Invalid URIs fail to parse. /// Uri::parse::("foo bar").expect_err("invalid URI"); /// ``` - pub fn parse(string: &'a str) -> Result, Error<'_>> + pub fn parse(string: &'a str) -> Result, Error<'a>> where T: Into> + TryFrom<&'a str, Error = Error<'a>> { T::try_from(string).map(|v| v.into()) @@ -127,7 +127,7 @@ impl<'a> Uri<'a> { /// let uri: Origin = uri!("/a/b/c?query"); /// let uri: Reference = uri!("/a/b/c?query#fragment"); /// ``` - pub fn parse_any(string: &'a str) -> Result, Error<'_>> { + pub fn parse_any(string: &'a str) -> Result, Error<'a>> { crate::parse::uri::from_str(string) } diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index d24a8356b7..d63e81c13a 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -163,8 +163,7 @@ pub type FilterResult<'r> = std::result::Result<(), Box + 'r> /// is called just after a request is received, immediately after /// pre-processing the request and running all `Request` fairings. This method /// returns a `Result`, which can be used to terminate processing of a request, -// TODO: Typed: links -/// bypassing the routing process. The error value must be a `TypedError`, which +/// bypassing the routing process. The error value must be a [`TypedError`], which /// can then be caught by a typed catcher. /// /// This method should only be used for global filters, i.e., filters that need @@ -172,6 +171,8 @@ pub type FilterResult<'r> = std::result::Result<(), Box + 'r> /// CORS, since the CORS headers of every request need to be inspected, and potentially /// rejected. /// +/// [`TypedError`]: crate::catcher::TypedError +/// /// * **Response (`on_response`)** /// /// A response callback, represented by the [`Fairing::on_response()`] diff --git a/core/lib/src/form/context.rs b/core/lib/src/form/context.rs index b733c83607..e91b281f8d 100644 --- a/core/lib/src/form/context.rs +++ b/core/lib/src/form/context.rs @@ -219,7 +219,7 @@ impl<'v> Context<'v> { /// let foo_bar = form.context.field_errors("foo.bar"); /// } /// ``` - pub fn field_errors<'a, N>(&'a self, name: N) -> impl Iterator> + '_ + pub fn field_errors<'a, N>(&'a self, name: N) -> impl Iterator> + 'a where N: AsRef + 'a { self.errors.values() @@ -267,7 +267,7 @@ impl<'v> Context<'v> { /// let foo_bar = form.context.exact_field_errors("foo.bar"); /// } /// ``` - pub fn exact_field_errors<'a, N>(&'a self, name: N) -> impl Iterator> + '_ + pub fn exact_field_errors<'a, N>(&'a self, name: N) -> impl Iterator> + 'a where N: AsRef + 'a { self.errors.values() From 0b728678da816ee38fdbbb939a24a70c204e8f84 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Thu, 21 Nov 2024 23:27:14 -0600 Subject: [PATCH 15/20] Clean up initial implementation --- core/lib/src/catcher/handler.rs | 1 - core/lib/src/catcher/types.rs | 41 ++++------------------------ core/lib/src/lifecycle.rs | 10 +++++-- core/lib/src/request/from_request.rs | 31 +++++++++++++++++++-- core/lib/src/route/handler.rs | 10 ++++++- examples/error-handling/src/tests.rs | 2 -- examples/serialization/src/tests.rs | 2 -- scripts/mk-docs.sh | 2 +- 8 files changed, 51 insertions(+), 48 deletions(-) diff --git a/core/lib/src/catcher/handler.rs b/core/lib/src/catcher/handler.rs index e285f8411a..ac95422cba 100644 --- a/core/lib/src/catcher/handler.rs +++ b/core/lib/src/catcher/handler.rs @@ -89,7 +89,6 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>; /// directly as the parameter to `rocket.register("/", )`. /// 3. Unlike static-function-based handlers, this custom handler can make use /// of internal state. -// TODO: Typed: Docs #[crate::async_trait] pub trait Handler: Cloneable + Send + Sync + 'static { /// Called by Rocket when an error with `status` for a given `Request` diff --git a/core/lib/src/catcher/types.rs b/core/lib/src/catcher/types.rs index ab96544e64..751d4f671b 100644 --- a/core/lib/src/catcher/types.rs +++ b/core/lib/src/catcher/types.rs @@ -34,7 +34,7 @@ mod sealed { /// This is the core of typed catchers. If an error type (returned by /// FromParam, FromRequest, FromForm, FromData, or Responder) implements -/// this trait, it can be caught by a typed catcher. (TODO) This trait +/// this trait, it can be caught by a typed catcher. This trait /// can be derived. pub trait TypedError<'r>: AsAny> + Send + Sync + 'r { /// Generates a default response for this type (or forwards to a default catcher) @@ -46,21 +46,14 @@ pub trait TypedError<'r>: AsAny> + Send + Sync + 'r { /// A descriptive name of this error type. Defaults to the type name. fn name(&self) -> &'static str { std::any::type_name::() } - // /// The error that caused this error. Defaults to None. - // /// - // /// # Warning - // /// A typed catcher will not attempt to follow the source of an error - // /// more than (TODO: exact number) 5 times. - // fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { None } - - // TODO: Typed: need to support case where there are multiple errors /// The error that caused this error. Defaults to None. Each source /// should only be returned for one index - this method will be called - /// with indicies starting with 0, and increasing until it returns None. + /// with indicies starting with 0, and increasing until it returns None, + /// or reaches 5. /// /// # Warning /// A typed catcher will not attempt to follow the source of an error - /// more than (TODO: exact number) 5 times. + /// more than 5 times. fn source(&'r self, _idx: usize) -> Option<&'r (dyn TypedError<'r> + 'r)> { None } /// Status code @@ -76,7 +69,6 @@ impl<'r> TypedError<'r> for Status { } fn name(&self) -> &'static str { - // TODO: Status generally shouldn't be caught "" } @@ -96,39 +88,18 @@ impl AsStatus for Box + '_> { self.status() } } -// TODO: Typed: update transient to make the possible. -// impl<'r, R: TypedError<'r> + Transient> TypedError<'r> for (Status, R) -// where R::Transience: CanTranscendTo> -// { -// fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { -// self.1.respond_to(request) -// } - -// fn name(&self) -> &'static str { -// self.1.name() -// } - -// fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> { -// Some(&self.1) -// } - -// fn status(&self) -> Status { -// self.0 -// } -// } impl<'r, A: TypedError<'r> + Transient, B: TypedError<'r> + Transient> TypedError<'r> for (A, B) where A::Transience: CanTranscendTo>, B::Transience: CanTranscendTo>, - // (A, B): Transient, - // <(A, B) as Transient>::Transience: CanTranscendTo>, { fn respond_to(&self, request: &'r Request<'_>) -> Result, Status> { self.0.respond_to(request).or_else(|_| self.1.respond_to(request)) } fn name(&self) -> &'static str { - // TODO: Typed: Should indicate that the + // TODO: This should make it more clear that both `A` and `B` work, but + // would likely require const concatenation. std::any::type_name::<(A, B)>() } diff --git a/core/lib/src/lifecycle.rs b/core/lib/src/lifecycle.rs index 24df7c6097..2eaecaff3b 100644 --- a/core/lib/src/lifecycle.rs +++ b/core/lib/src/lifecycle.rs @@ -227,7 +227,7 @@ impl Rocket { ) -> route::Outcome<'r> { // Go through all matching routes until we fail or succeed or run out of // routes to try, in which case we forward with the last status. - let mut status: Box + 'r> = Box::new(Status::NotFound); + let mut error: Box + 'r> = Box::new(Status::NotFound); for route in self.router.route(request) { // Retrieve and set the requests parameters. route.trace_info(); @@ -243,11 +243,11 @@ impl Rocket { outcome.trace_info(); match outcome { o @ Outcome::Success(_) | o @ Outcome::Error(_) => return o, - Outcome::Forward(forwarded) => (data, status) = forwarded, + Outcome::Forward(forwarded) => (data, error) = forwarded, } } - Outcome::Forward((data, status)) + Outcome::Forward((data, error)) } // Invokes the catcher for `status`. Returns the response on success. @@ -329,6 +329,7 @@ impl Rocket { /// 5 calls deep /// * Matching Status, but not Type /// * Default handler + /// * Error type's default handler /// * Rocket's default /// /// The handler selected to be invoked is the one with the lowest rank. @@ -361,6 +362,9 @@ impl Rocket { .await .map(|result| result.map_err(Some)) .unwrap_or_else(|| Err(None)) + // TODO: Typed: should this be run in a `catch_unwind` context? + } else if let Ok(res) = error.respond_to(req) { + Ok(res) } else { info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = error.status().code, "no registered catcher: using Rocket default"); diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index dff3b9febf..8d90e8b531 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -81,7 +81,7 @@ pub type Outcome = outcome::Outcome; /// the value for the corresponding parameter. As long as all other guards /// succeed, the request will be handled. /// -/// * **Error**(Status, E) +/// * **Error**(E) /// /// If the `Outcome` is [`Error`], the request will fail with the given /// status code and error. The designated error [`Catcher`](crate::Catcher) @@ -89,7 +89,7 @@ pub type Outcome = outcome::Outcome; /// of `Result` and `Option` to catch `Error`s and retrieve the /// error value. /// -/// * **Forward**(Status) +/// * **Forward**(E) /// /// If the `Outcome` is [`Forward`], the request will be forwarded to the next /// matching route until either one succeeds or there are no further matching @@ -242,6 +242,32 @@ pub type Outcome = outcome::Outcome; /// } /// ``` /// +/// ## Errors +/// +/// When a request guard fails, the error type can be caught using a catcher. A catcher +/// for the above example might look something like this: +/// +/// ```rust +/// # #[macro_use] extern crate rocket; +/// # use rocket::http::Status; +/// # use rocket::request::{self, Outcome, Request, FromRequest}; +/// # #[derive(Debug, TypedError)] +/// # #[error(status = 400)] +/// # enum ApiKeyError { +/// # Missing, +/// # Invalid, +/// # } +/// #[catch(400, error = "")] +/// fn catch_api_key_error(e: &ApiKeyError) -> &'static str { +/// match e { +/// ApiKeyError::Missing => "Api key required", +/// ApiKeyError::Invalid => "Api key is invalid", +/// } +/// } +/// ``` +/// +/// See [typed catchers](crate::catch) for more information. +/// /// # Request-Local State /// /// Request guards that perform expensive operations, such as those that query a @@ -379,7 +405,6 @@ pub type Outcome = outcome::Outcome; /// User` and `Admin<'a>`) as the data is now owned by the request's cache. /// /// [request-local state]: https://rocket.rs/master/guide/state/#request-local-state -// TODO: Typed: docs #[crate::async_trait] pub trait FromRequest<'r>: Sized { /// The associated error to be returned if derivation fails. diff --git a/core/lib/src/route/handler.rs b/core/lib/src/route/handler.rs index 9057fd3963..88fb33ee90 100644 --- a/core/lib/src/route/handler.rs +++ b/core/lib/src/route/handler.rs @@ -30,6 +30,15 @@ pub type BoxFuture<'r, T = Outcome<'r>> = futures::future::BoxFuture<'r, T>; /// This is an _async_ trait. Implementations must be decorated /// [`#[rocket::async_trait]`](crate::async_trait). /// +/// ## Errors +/// +/// If the handler errors or forwards, the implementation must include a +/// [`Box`]. Any type that implements [`TypedError`] can +/// be boxed upcast, see [`TypedError`] docs for more information. +/// +/// [`Box`]: crate::catcher::TypedError +/// [`TypedError`]: crate::catcher::TypedError +/// /// # Example /// /// Say you'd like to write a handler that changes its functionality based on an @@ -137,7 +146,6 @@ pub type BoxFuture<'r, T = Outcome<'r>> = futures::future::BoxFuture<'r, T>; /// Use this alternative when a single configuration is desired and your custom /// handler is private to your application. For all other cases, a custom /// `Handler` implementation is preferred. -// TODO: Typed: Docs #[crate::async_trait] pub trait Handler: Cloneable + Send + Sync + 'static { /// Called by Rocket when a `Request` with its associated `Data` should be diff --git a/examples/error-handling/src/tests.rs b/examples/error-handling/src/tests.rs index 735fd9f566..785f21cb93 100644 --- a/examples/error-handling/src/tests.rs +++ b/examples/error-handling/src/tests.rs @@ -75,8 +75,6 @@ fn test_hello_invalid_age() { fn test_hello_sergio() { let client = Client::tracked(super::rocket()).unwrap(); - // TODO: typed: This logic has changed, either needs to be fixed - // or this test changed. for path in &["oops", "-129"] { let request = client.get(format!("/hello/Sergio/{}", path)); let expected = super::sergio_error(); diff --git a/examples/serialization/src/tests.rs b/examples/serialization/src/tests.rs index bff72896f1..5ede944d17 100644 --- a/examples/serialization/src/tests.rs +++ b/examples/serialization/src/tests.rs @@ -38,9 +38,7 @@ fn json_bad_get_put() { // Try to put a message without a proper body. let res = client.put("/json/80").header(ContentType::JSON).dispatch(); - // TODO: Typed: This behavior has changed assert_eq!(res.status(), Status::UnprocessableEntity); - // Status::BadRequest); // Try to put a message with a semantically invalid body. let res = client.put("/json/0") diff --git a/scripts/mk-docs.sh b/scripts/mk-docs.sh index 1f3d6149b8..7fbaecb537 100755 --- a/scripts/mk-docs.sh +++ b/scripts/mk-docs.sh @@ -25,7 +25,7 @@ pushd "${PROJECT_ROOT}" > /dev/null 2>&1 --crate-version ${DOC_VERSION} \ --enable-index-page \ --generate-link-to-definition" \ - cargo doc -Zrustdoc-map --no-deps --all-features \ + cargo +nightly doc -Zrustdoc-map --no-deps --all-features \ -p rocket \ -p rocket_db_pools \ -p rocket_sync_db_pools \ From 1a646075475493b23b3f32722cd3ab029032adfe Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Fri, 22 Nov 2024 10:46:14 -0600 Subject: [PATCH 16/20] Fix minor issues --- core/codegen/src/attribute/catch/mod.rs | 3 ++- core/codegen/src/attribute/route/mod.rs | 15 +++++++++++---- core/codegen/src/derive/typed_error.rs | 13 ++++++++----- core/codegen/src/lib.rs | 2 +- core/codegen/tests/catcher.rs | 6 ++++-- core/codegen/tests/typed_error.rs | 11 +++++++++++ core/lib/src/router/collider.rs | 4 +++- core/lib/src/router/router.rs | 5 +++-- 8 files changed, 43 insertions(+), 16 deletions(-) diff --git a/core/codegen/src/attribute/catch/mod.rs b/core/codegen/src/attribute/catch/mod.rs index bd1f081cf1..2f1ba84a22 100644 --- a/core/codegen/src/attribute/catch/mod.rs +++ b/core/codegen/src/attribute/catch/mod.rs @@ -61,7 +61,8 @@ pub fn _catch( ::rocket::trace::error!( downcast_to = stringify!(#ty), error_name = #__error.name(), - "Failed to downcast error. This should never happen, please open an issue with details." + "Failed to downcast error. This should never happen, please \ + open an issue with details." ); return #_Err(#Status::InternalServerError); }, diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index ba43576332..169a5e4be5 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -113,6 +113,12 @@ fn query_decls(route: &Route) -> Option { "{_err}" ); } } ); + ::rocket::trace::info!( + name: "forward", + target: concat!("rocket::codegen::route::", module_path!()), + error_name = #TypedError::name(&__e), + "parameter guard forwarding" + ); return #Outcome::Forward(( #__data, @@ -141,6 +147,7 @@ fn request_guard_decl(guard: &Guard) -> TokenStream { parameter = stringify!(#ident), type_name = stringify!(#ty), status = #TypedError::status(&__e).code, + error_name = #TypedError::name(&__e), "request guard forwarding" ); @@ -156,8 +163,8 @@ fn request_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), + status = #TypedError::status(&__c).code, error_name = #TypedError::name(&__c), - // reason = %#display_hack!(__e), "request guard failed" ); @@ -181,8 +188,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = #name, type_name = stringify!(#ty), - name = #TypedError::name(&__error), - // reason = %#display_hack!(__error), + error_name = #TypedError::name(&__error), "path guard forwarding" ); @@ -248,6 +254,7 @@ fn data_guard_decl(guard: &Guard) -> TokenStream { parameter = stringify!(#ident), type_name = stringify!(#ty), status = #TypedError::status(&__e).code, + error_name = #TypedError::name(&__e), "data guard forwarding" ); @@ -263,7 +270,7 @@ fn data_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - // reason = %#display_hack!(__e), + error_name = #TypedError::name(&__e), "data guard failed" ); diff --git a/core/codegen/src/derive/typed_error.rs b/core/codegen/src/derive/typed_error.rs index 17efb4bd4e..c3c9a68bfd 100644 --- a/core/codegen/src/derive/typed_error.rs +++ b/core/codegen/src/derive/typed_error.rs @@ -128,9 +128,12 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { }) ) .validator(ValidatorBuild::new() - .input_validate(|_, i| match i.generics().lifetimes().count() > 1 { - true => Err(i.generics().span().error("only one lifetime is supported")), - false => Ok(()) + .input_validate(|_, i| if i.generics().lifetimes().count() > 1 { + Err(i.generics().span().error("only one lifetime is supported")) + } else if i.generics().const_params().count() > 0 { + Err(i.generics().span().error("const params are not supported")) + } else { + Ok(()) }) ) .inner_mapper(MapperBuild::new() @@ -146,14 +149,14 @@ pub fn derive_typed_error(input: proc_macro::TokenStream) -> TokenStream { match g { syn::GenericParam::Lifetime(_) => quote!{ 'static }, syn::GenericParam::Type(TypeParam { ident, .. }) => quote! { #ident }, - syn::GenericParam::Const(ConstParam { .. }) => todo!(), + syn::GenericParam::Const(ConstParam { .. }) => unreachable!(), } }); let trans = input.generics() .lifetimes() .map(|LifetimeParam { lifetime, .. }| quote!{#_catcher::Inv<#lifetime>}); quote!{ - type Static = #name <#(#args)*>; + type Static = #name <#(#args,)*>; type Transience = (#(#trans,)*); } }) diff --git a/core/codegen/src/lib.rs b/core/codegen/src/lib.rs index c4a5ca8ac9..1282b4cf33 100644 --- a/core/codegen/src/lib.rs +++ b/core/codegen/src/lib.rs @@ -1035,7 +1035,7 @@ pub fn derive_responder(input: TokenStream) -> TokenStream { /// name: &'r str, /// value: &'r str, /// } -/// +/// /// #[derive(TypedError)] /// enum HeaderError { /// InvalidValue, diff --git a/core/codegen/tests/catcher.rs b/core/codegen/tests/catcher.rs index 3887b9a9d9..e6dd4f8b5e 100644 --- a/core/codegen/tests/catcher.rs +++ b/core/codegen/tests/catcher.rs @@ -105,11 +105,13 @@ fn test_basic_params() { #[catch(default, error = "")] fn test_io_error(e: &io::Error) -> String { format!("{e:?}") } #[catch(default, error = "<_e>")] -fn test_parse_int_error(_e: &ParseIntError) -> String { println!("ParseIntError"); format!("ParseIntError") } +fn test_parse_int_error(_e: &ParseIntError) -> String { format!("ParseIntError") } #[catch(default, error = "<_e>")] fn test_parse_bool_error(_e: &ParseBoolError) -> String { format!("ParseBoolError") } #[catch(default, error = "")] -fn test_param_parse_bool_error(e: &FromParamError<'_, ParseBoolError>) -> String { format!("ParseBoolError: {}", e.raw) } +fn test_param_parse_bool_error(e: &FromParamError<'_, ParseBoolError>) -> String { + format!("ParseBoolError: {}", e.raw) +} #[test] diff --git a/core/codegen/tests/typed_error.rs b/core/codegen/tests/typed_error.rs index cbe9ae6cfd..1f47bef56f 100644 --- a/core/codegen/tests/typed_error.rs +++ b/core/codegen/tests/typed_error.rs @@ -67,3 +67,14 @@ fn validate_static() { )); boxed_error(Box::new(val)); } + +#[derive(TypedError)] +pub enum Generic { + First(E), +} + +#[derive(TypedError)] +pub struct GenericWithLifetime<'r, E> { + s: &'r str, + inner: E, +} diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index c62d3ef6ad..7cdadf9473 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -141,7 +141,9 @@ impl Catcher { /// assert!(!a.collides_with(&b)); /// ``` pub fn collides_with(&self, other: &Self) -> bool { - self.code == other.code && self.base().segments().eq(other.base().segments()) && self.type_id == other.type_id + self.code == other.code + && self.base().segments().eq(other.base().segments()) + && self.type_id == other.type_id } } diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index b810c21bd1..d95540c664 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -593,7 +593,8 @@ mod test { { let mut router = Router::new(); for (status, base, ty) in catchers { - let mut catcher = Catcher::new(status.map(|s| s.code), catcher::dummy_handler).rebase(Origin::parse(base).unwrap()); + let mut catcher = Catcher::new(status.map(|s| s.code), catcher::dummy_handler) + .rebase(Origin::parse(base).unwrap()); catcher.type_id = ty; router.catchers.push(catcher); } @@ -646,7 +647,7 @@ mod test { let request = client.req(Method::Get, Origin::parse(uri).unwrap()); router.catch_any(status, ty(&()), &request) } - + #[test] fn test_catch_vs_catch_any() { let router = make_router_catches([(None, "/", None)]).unwrap(); From 4ad4c08e60db7c6ca65ccb25e03f92ee0b22a9ea Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Fri, 22 Nov 2024 12:56:26 -0600 Subject: [PATCH 17/20] Update s2n-quic-h3 to fix compatibility issue --- core/lib/Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index 4a50a641d8..503b4450a6 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -131,7 +131,8 @@ optional = true [dependencies.s2n-quic-h3] # git = "https://github.com/SergioBenitez/s2n-quic-h3.git" # rev = "6613956" -path = "../../../s2n-quic-h3" +git = "https://github.com/the10thwiz/s2n-quic-h3.git" +rev = "a141cc1" optional = true [target.'cfg(unix)'.dependencies] From 443d38151d6da2314024c021bfce00488113a415 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Fri, 22 Nov 2024 12:59:23 -0600 Subject: [PATCH 18/20] Remove old dependency --- examples/error-handling/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/error-handling/Cargo.toml b/examples/error-handling/Cargo.toml index a8f17a5bbf..b2155b5726 100644 --- a/examples/error-handling/Cargo.toml +++ b/examples/error-handling/Cargo.toml @@ -7,4 +7,3 @@ publish = false [dependencies] rocket = { path = "../../core/lib", features = ["json"] } -transient = { path = "/code/matthew/transient" } From 10b9faa30aedeae9ddfa72e61760c1b2e0e12a00 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Fri, 22 Nov 2024 13:30:06 -0600 Subject: [PATCH 19/20] Re-add display for error to tracing --- core/codegen/src/attribute/route/mod.rs | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/core/codegen/src/attribute/route/mod.rs b/core/codegen/src/attribute/route/mod.rs index 169a5e4be5..4226dc7b1f 100644 --- a/core/codegen/src/attribute/route/mod.rs +++ b/core/codegen/src/attribute/route/mod.rs @@ -41,7 +41,7 @@ fn query_decls(route: &Route) -> Option { } define_spanned_export!(Span::call_site() => - __req, __data, _form, Outcome, _Ok, _Err, _Some, _None, TypedError + __req, __data, _form, Outcome, _Ok, _Err, _Some, _None, TypedError, display_hack ); // Record all of the static parameters for later filtering. @@ -116,6 +116,7 @@ fn query_decls(route: &Route) -> Option { ::rocket::trace::info!( name: "forward", target: concat!("rocket::codegen::route::", module_path!()), + error = %#display_hack!(&__e), error_name = #TypedError::name(&__e), "parameter guard forwarding" ); @@ -134,7 +135,7 @@ fn query_decls(route: &Route) -> Option { fn request_guard_decl(guard: &Guard) -> TokenStream { let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); define_spanned_export!(ty.span() => - __req, __data, _request, FromRequest, Outcome, TypedError + __req, __data, _request, FromRequest, Outcome, TypedError, display_hack ); quote_spanned! { ty.span() => @@ -146,6 +147,7 @@ fn request_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), + error = %#display_hack!(&__e), status = #TypedError::status(&__e).code, error_name = #TypedError::name(&__e), "request guard forwarding" @@ -157,18 +159,19 @@ fn request_guard_decl(guard: &Guard) -> TokenStream { )); }, #[allow(unreachable_code)] - #Outcome::Error(__c) => { + #Outcome::Error(__e) => { ::rocket::trace::info!( name: "failure", target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), - status = #TypedError::status(&__c).code, - error_name = #TypedError::name(&__c), + error = %#display_hack!(&__e), + status = #TypedError::status(&__e).code, + error_name = #TypedError::name(&__e), "request guard failed" ); - return #Outcome::Error(Box::new(__c) as Box + '__r>); + return #Outcome::Error(Box::new(__e) as Box + '__r>); } }; } @@ -178,7 +181,8 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { let (i, name, ty) = (guard.index, &guard.name, &guard.ty); define_spanned_export!(ty.span() => __req, __data, _None, _Some, _Ok, _Err, - Outcome, FromSegments, FromParam, Status, TypedError, FromParamError, FromSegmentsError + Outcome, FromSegments, FromParam, Status, TypedError, FromParamError, FromSegmentsError, + display_hack ); // Returned when a dynamic parameter fails to parse. @@ -188,6 +192,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = #name, type_name = stringify!(#ty), + error = %#display_hack!(&__error), error_name = #TypedError::name(&__error), "path guard forwarding" ); @@ -242,7 +247,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream { fn data_guard_decl(guard: &Guard) -> TokenStream { let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty); - define_spanned_export!(ty.span() => __req, __data, FromData, Outcome, TypedError); + define_spanned_export!(ty.span() => __req, __data, FromData, Outcome, TypedError, display_hack); quote_spanned! { ty.span() => let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await { @@ -253,6 +258,7 @@ fn data_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), + error = %#display_hack!(&__e), status = #TypedError::status(&__e).code, error_name = #TypedError::name(&__e), "data guard forwarding" @@ -270,6 +276,7 @@ fn data_guard_decl(guard: &Guard) -> TokenStream { target: concat!("rocket::codegen::route::", module_path!()), parameter = stringify!(#ident), type_name = stringify!(#ty), + error = %#display_hack!(&__e), error_name = #TypedError::name(&__e), "data guard failed" ); From b3806e66ade22630e9dda07b873cf857c5bd0149 Mon Sep 17 00:00:00 2001 From: Matthew Pomes Date: Fri, 22 Nov 2024 14:10:42 -0600 Subject: [PATCH 20/20] Fix more small things --- core/lib/src/data/data.rs | 2 +- core/lib/src/data/from_data.rs | 6 ++++-- core/lib/src/request/from_param.rs | 8 ++++---- core/lib/src/request/from_request.rs | 5 +++-- docs/guide/12-pastebin.md | 4 ++-- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/core/lib/src/data/data.rs b/core/lib/src/data/data.rs index d5ea8d9e1d..82db59fb92 100644 --- a/core/lib/src/data/data.rs +++ b/core/lib/src/data/data.rs @@ -114,7 +114,7 @@ impl<'r> Data<'r> { /// use rocket::data::{Data, FromData, Outcome}; /// use rocket::http::Status; /// # struct MyType; - /// # #[derive(rocket::TypedError)] + /// # #[derive(rocket::TypedError, Debug)] /// # struct MyError; /// /// #[rocket::async_trait] diff --git a/core/lib/src/data/from_data.rs b/core/lib/src/data/from_data.rs index 758c6466dd..8abf2dc79d 100644 --- a/core/lib/src/data/from_data.rs +++ b/core/lib/src/data/from_data.rs @@ -1,3 +1,5 @@ +use std::fmt; + use crate::catcher::TypedError; use crate::http::RawStr; use crate::request::{Request, local_cache}; @@ -182,7 +184,7 @@ pub type Outcome<'r, T, E = >::Error> /// use rocket::request::Request; /// use rocket::data::{self, Data, FromData}; /// # struct MyType; -/// # #[derive(rocket::TypedError)] +/// # #[derive(rocket::TypedError, Debug)] /// # struct MyError; /// /// #[rocket::async_trait] @@ -312,7 +314,7 @@ pub type Outcome<'r, T, E = >::Error> #[crate::async_trait] pub trait FromData<'r>: Sized { /// The associated error to be returned when the guard fails. - type Error: TypedError<'r> + 'r; + type Error: TypedError<'r> + fmt::Debug + 'r; /// Asynchronously validates, parses, and converts an instance of `Self` /// from the incoming request body data. diff --git a/core/lib/src/request/from_param.rs b/core/lib/src/request/from_param.rs index 8257bf7ddc..aed9356c34 100644 --- a/core/lib/src/request/from_param.rs +++ b/core/lib/src/request/from_param.rs @@ -160,7 +160,7 @@ use crate::http::{uri::{Segments, error::PathError, fmt::Path}, Status}; /// use rocket::TypedError; /// # #[allow(dead_code)] /// # struct MyParam<'r> { key: &'r str, value: usize } -/// #[derive(TypedError)] +/// #[derive(TypedError, Debug)] /// struct MyParamError<'a>(&'a str); /// /// impl<'r> FromParam<'r> for MyParam<'r> { @@ -192,7 +192,7 @@ use crate::http::{uri::{Segments, error::PathError, fmt::Path}, Status}; /// # #[macro_use] extern crate rocket; /// # use rocket::request::FromParam; /// # use rocket::TypedError; -/// # #[derive(TypedError)] +/// # #[derive(TypedError, Debug)] /// # struct MyParamError<'a>(&'a str); /// # #[allow(dead_code)] /// # struct MyParam<'r> { key: &'r str, value: usize } @@ -215,7 +215,7 @@ use crate::http::{uri::{Segments, error::PathError, fmt::Path}, Status}; /// ``` pub trait FromParam<'a>: Sized { /// The associated error to be returned if parsing/validation fails. - type Error: TypedError<'a>; + type Error: TypedError<'a> + fmt::Debug + 'a; /// Parses and validates an instance of `Self` from a path parameter string /// or returns an `Error` if parsing or validation fails. @@ -397,7 +397,7 @@ impl<'a, T: FromParam<'a>> FromParam<'a> for Option { /// the `Utf8Error`. pub trait FromSegments<'r>: Sized { /// The associated error to be returned when parsing fails. - type Error: TypedError<'r>; + type Error: TypedError<'r> + fmt::Debug + 'r; /// Parses an instance of `Self` from many dynamic path parameter strings or /// returns an `Error` if one cannot be parsed. diff --git a/core/lib/src/request/from_request.rs b/core/lib/src/request/from_request.rs index 8d90e8b531..453ffe4ac2 100644 --- a/core/lib/src/request/from_request.rs +++ b/core/lib/src/request/from_request.rs @@ -1,4 +1,5 @@ use std::convert::Infallible; +use std::fmt; use std::net::{IpAddr, SocketAddr}; use crate::catcher::TypedError; @@ -36,7 +37,7 @@ pub type Outcome = outcome::Outcome; /// use rocket::request::{self, Request, FromRequest}; /// # struct MyType; /// # use rocket::TypedError; -/// # #[derive(TypedError)] +/// # #[derive(TypedError, Debug)] /// # struct MyError; /// /// #[rocket::async_trait] @@ -408,7 +409,7 @@ pub type Outcome = outcome::Outcome; #[crate::async_trait] pub trait FromRequest<'r>: Sized { /// The associated error to be returned if derivation fails. - type Error: TypedError<'r> + 'r; + type Error: TypedError<'r> + fmt::Debug + 'r; /// Derives an instance of `Self` from the incoming request metadata. /// diff --git a/docs/guide/12-pastebin.md b/docs/guide/12-pastebin.md index 36182be5f8..56147a345d 100644 --- a/docs/guide/12-pastebin.md +++ b/docs/guide/12-pastebin.md @@ -367,7 +367,7 @@ use rocket::tokio::fs::File; # pub fn new(size: usize) -> PasteId<'static> { todo!() } # pub fn file_path(&self) -> PathBuf { todo!() } # } -# #[derive(rocket::TypedError)] +# #[derive(Debug, rocket::TypedError)] # pub struct InvalidPasteId; # impl<'a> FromParam<'a> for PasteId<'a> { # type Error = InvalidPasteId; @@ -446,7 +446,7 @@ pub struct PasteId<'a>(Cow<'a, str>); # pub fn file_path(&self) -> PathBuf { todo!() } # } # -# #[derive(rocket::TypedError)] +# #[derive(Debug, rocket::TypedError)] # pub struct InvalidPasteId; # impl<'a> FromParam<'a> for PasteId<'a> { # type Error = InvalidPasteId;