diff --git a/src/routing.rs b/src/routing.rs index 4a74763b66..2112226c45 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -354,6 +354,16 @@ pub trait RoutingDsl: crate::sealed::Sealed + Sized { { IntoMakeServiceWithConnectInfo::new(self) } + + // TODO(david): Could doing `.or` with some random service lead to strange + // behavior? + // Could `where S: RoutingDsl` prevent misuse? + fn or(self, other: S) -> Or { + Or { + first: self, + second: other, + } + } } impl RoutingDsl for Route {} @@ -507,7 +517,11 @@ impl RoutingDsl for EmptyRouter {} impl crate::sealed::Sealed for EmptyRouter {} -impl Service> for EmptyRouter { +impl Service> for EmptyRouter +where + // TODO(david): breaking change + B: Send + Sync + 'static, +{ type Response = Response; type Error = E; type Future = EmptyRouterFuture; @@ -516,8 +530,9 @@ impl Service> for EmptyRouter { Poll::Ready(Ok(())) } - fn call(&mut self, _req: Request) -> Self::Future { + fn call(&mut self, request: Request) -> Self::Future { let mut res = Response::new(crate::body::empty()); + res.extensions_mut().insert(FromEmptyRouter { request }); *res.status_mut() = self.status; EmptyRouterFuture { future: future::ok(res), @@ -531,6 +546,16 @@ opaque_future! { future::Ready, E>>; } +/// Response extension used by [`EmptyRouter`] to send the request back to [`Or`] so +/// the other service can be called. +/// +/// Without this we would loose ownership of the response when calling the first +/// service in [`Or`]. We also wouldn't be able to identify if the response came +/// from [`EmptyRouter`] and therefore can be discarded in [`Or`]. +struct FromEmptyRouter { + request: Request, +} + #[derive(Debug, Clone)] pub(crate) struct PathPattern(Arc); @@ -1001,6 +1026,57 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Uri { Uri::from_parts(parts).unwrap() } +#[derive(Debug, Clone, Copy)] +pub struct Or { + first: A, + second: B, +} + +impl RoutingDsl for Or {} + +impl crate::sealed::Sealed for Or {} + +#[allow(warnings)] +impl Service> for Or +where + A: Service, Response = Response> + Clone, + B: Service, Response = Response, Error = A::Error> + Clone, + ReqBody: Send + Sync + 'static, + A: Send + 'static, + B: Send + 'static, + A::Future: Send + 'static, + B::Future: Send + 'static, +{ + type Response = Response; + type Error = A::Error; + // TODO(david): don't use a boxed future here + type Future = futures_util::future::BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let mut first = self.first.clone(); + let mut second = self.second.clone(); + + Box::pin(async move { + let mut response: Response = first.oneshot(req).await?; + + let req = if let Some(ext) = response + .extensions_mut() + .remove::>() + { + ext.request + } else { + return Ok(response); + }; + + second.oneshot(req).await + }) + } +} + #[cfg(test)] mod tests { use super::*;