From f078f5507e642c7d2fd5e8ea88fb84744818c3af Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Thu, 5 Sep 2024 23:16:21 +0200 Subject: [PATCH] Add new SetTaskLocal layer / service --- tower-http/Cargo.toml | 1 + tower-http/src/lib.rs | 3 ++ tower-http/src/macros.rs | 9 ++++- tower-http/src/task_local.rs | 78 ++++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 tower-http/src/task_local.rs diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 1e4b6cc6..044933a1 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -102,6 +102,7 @@ request-id = ["uuid"] sensitive-headers = [] set-header = [] set-status = [] +task-local = ["tokio/rt"] timeout = ["dep:http-body", "tokio/time"] trace = ["dep:http-body", "tracing"] util = ["tower"] diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 8d254e1d..345763e0 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -309,6 +309,9 @@ pub mod follow_redirect; #[cfg(feature = "limit")] pub mod limit; +#[cfg(feature = "task-local")] +pub mod task_local; + #[cfg(feature = "metrics")] pub mod metrics; diff --git a/tower-http/src/macros.rs b/tower-http/src/macros.rs index f58d34a6..be59446e 100644 --- a/tower-http/src/macros.rs +++ b/tower-http/src/macros.rs @@ -68,10 +68,15 @@ macro_rules! opaque_body { #[allow(unused_macros)] macro_rules! opaque_future { - ($(#[$m:meta])* pub type $name:ident<$($param:ident),+> = $actual:ty;) => { + ( + $(#[$m:meta])* + pub type $name:ident<$($param:ident),+> = $actual:ty + $( where $($tt:tt)* )? + ; + ) => { pin_project_lite::pin_project! { $(#[$m])* - pub struct $name<$($param),+> { + pub struct $name<$($param),+> $( where $($tt)* )? { #[pin] inner: $actual } diff --git a/tower-http/src/task_local.rs b/tower-http/src/task_local.rs new file mode 100644 index 00000000..56bd45d3 --- /dev/null +++ b/tower-http/src/task_local.rs @@ -0,0 +1,78 @@ +//! Middleware to set tokio task-local data. + +use std::future::Future; + +use http::{Request, Response}; +use tokio::task::{futures::TaskLocalFuture, LocalKey}; +use tower_layer::Layer; +use tower_service::Service; + +#[derive(Debug, Clone, Copy)] +pub struct SetTaskLocalLayer { + key: &'static LocalKey, + value: T, +} + +impl SetTaskLocalLayer +where + T: Clone + Send + Sync + 'static, +{ + pub fn new(key: &'static LocalKey, value: T) -> Self { + SetTaskLocalLayer { key, value } + } +} + +impl Layer for SetTaskLocalLayer +where + T: Clone + Send + Sync + 'static, +{ + type Service = SetTaskLocal; + + fn layer(&self, inner: S) -> Self::Service { + SetTaskLocal::new(inner, self.key, self.value.clone()) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SetTaskLocal { + inner: S, + key: &'static LocalKey, + value: T, +} + +impl SetTaskLocal +where + T: Clone + Send + Sync + 'static, +{ + pub fn new(inner: S, key: &'static LocalKey, value: T) -> Self { + Self { inner, key, value } + } +} + +impl Service> for SetTaskLocal +where + S: Service, Response = Response>, + T: Clone + Send + Sync + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + self.key.scope(self.value.clone(), self.inner.call(req)) + } +} + +opaque_future! { + /// Response future of [`SetTaskLocal`]. + pub type ResponseFuture = TaskLocalFuture + where + S: Service>; +}