diff --git a/src/mock.rs b/src/mock.rs index c87999c..fafab9c 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -1,6 +1,6 @@ use crate::diff; use crate::matcher::{Matcher, PathAndQueryMatcher, RequestMatcher}; -use crate::response::{Body, Response}; +use crate::response::{Body, Header, Response}; use crate::server::RemoteMock; use crate::server::State; use crate::Request; @@ -370,11 +370,47 @@ impl Mock { self.inner .response .headers - .append(field.into_header_name(), value.to_owned()); + .append(field.into_header_name(), Header::String(value.to_string())); self } + /// + /// Sets the headers of the mock response dynamically while exposing the request object. + /// + /// You can use this method to provide custom headers for every incoming request. + /// + /// The function must be thread-safe. If it's a closure, it can't be borrowing its context. + /// Use `move` closures and `Arc` to share any data. + /// + /// ### Example + /// + /// ``` + /// let mut s = mockito::Server::new(); + /// + /// let _m = s.mock("GET", mockito::Matcher::Any).with_header_from_request("user", |request| { + /// if request.path() == "/bob" { + /// "bob".into() + /// } else if request.path() == "/alice" { + /// "alice".into() + /// } else { + /// "everyone".into() + /// } + /// }); + /// ``` + /// + pub fn with_header_from_request( + mut self, + field: T, + callback: impl Fn(&Request) -> String + Send + Sync + 'static, + ) -> Self { + self.inner.response.headers.append( + field.into_header_name(), + Header::FnWithRequest(Arc::new(move |req| callback(req))), + ); + self + } + /// /// Sets the body of the mock response. Its `Content-Length` is handled automatically. /// diff --git a/src/response.rs b/src/response.rs index 218cb87..051e668 100644 --- a/src/response.rs +++ b/src/response.rs @@ -13,10 +13,40 @@ use tokio::sync::mpsc; #[derive(Clone, Debug, PartialEq)] pub(crate) struct Response { pub status: StatusCode, - pub headers: HeaderMap, + pub headers: HeaderMap
, pub body: Body, } +#[derive(Clone)] +pub(crate) enum Header { + String(String), + FnWithRequest(Arc), +} + +impl fmt::Debug for Header { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Header::String(ref s) => s.fmt(f), + Header::FnWithRequest(_) => f.write_str(""), + } + } +} + +impl PartialEq for Header { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Header::String(ref a), Header::String(ref b)) => a == b, + (Header::FnWithRequest(ref a), Header::FnWithRequest(ref b)) => std::ptr::eq( + a.as_ref() as *const HeaderFnWithRequest as *const u8, + b.as_ref() as *const HeaderFnWithRequest as *const u8, + ), + _ => false, + } + } +} + +type HeaderFnWithRequest = dyn Fn(&Request) -> String + Send + Sync; + type BodyFnWithWriter = dyn Fn(&mut dyn io::Write) -> io::Result<()> + Send + Sync + 'static; type BodyFnWithRequest = dyn Fn(&Request) -> Bytes + Send + Sync + 'static; @@ -57,7 +87,7 @@ impl PartialEq for Body { impl Default for Response { fn default() -> Self { let mut headers = HeaderMap::with_capacity(1); - headers.insert("connection", "close".parse().unwrap()); + headers.insert("connection", Header::String("close".to_string())); Self { status: StatusCode::OK, headers, diff --git a/src/server.rs b/src/server.rs index 03530fd..9d50b16 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,6 @@ use crate::mock::InnerMock; use crate::request::Request; -use crate::response::{Body as ResponseBody, ChunkedStream}; +use crate::response::{Body as ResponseBody, ChunkedStream, Header}; use crate::ServerGuard; use crate::{Error, ErrorKind, Matcher, Mock}; use bytes::Bytes; @@ -559,7 +559,12 @@ fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result response = response.header(name, value), + Header::FnWithRequest(header_fn) => { + response = response.header(name, header_fn(&request)) + } + } } let body = if request.method() != "HEAD" { diff --git a/tests/lib.rs b/tests/lib.rs index 3c8ce7d..4ce3b78 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -677,6 +677,25 @@ fn test_mock_with_header() { assert!(headers.contains(&"content-type: application/json".to_string())); } +#[test] +fn test_mock_with_header_from_request() { + let mut s = Server::new(); + s.mock("GET", Matcher::Any) + .with_header_from_request("user", |req| { + if req.path() == "/alice" { + "alice".into() + } else { + "everyone".into() + } + }) + .create(); + + let (_, headers, _) = request(s.host_with_port(), "GET /alice", ""); + assert!(headers.contains(&"user: alice".to_string())); + let (_, headers, _) = request(s.host_with_port(), "GET /anyone-else", ""); + assert!(headers.contains(&"user: everyone".to_string())); +} + #[test] fn test_mock_with_multiple_headers() { let mut s = Server::new();