diff --git a/Cargo.toml b/Cargo.toml index a69e37a..7a3e53e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,19 +20,21 @@ path = "src/lib.rs" [dependencies] log = "0.4" -http-types = { version = "2.11", default-features = false, features = ["hyperium_http"] } serde_json = "1" serde = "1" regex = "1" -futures-timer = "3.0.2" futures = "0.3.5" -hyper = { version = "0.14", features = ["full"] } +http = "1.0" +http-body-util = "0.1" +hyper = { version = "1.0", features = ["full"] } +hyper-util = { version = "0.1", features = ["tokio"] } tokio = { version = "1.5.0", features = ["rt"] } deadpool = "0.9.2" async-trait = "0.1" once_cell = "1" assert-json-diff = "2.0.1" base64 = "0.21.0" +url = "2.2" [dev-dependencies] async-std = { version = "1.9.0", features = ["attributes"] } @@ -40,4 +42,3 @@ surf = "2.3.2" reqwest = "0.11.3" tokio = { version = "1.5.0", features = ["macros", "rt-multi-thread"] } actix-rt = "2.2.0" -isahc = "1.3.1" diff --git a/src/http.rs b/src/http.rs index ad7ad00..c9ca8a0 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,3 +1,3 @@ -//! Convenient re-exports of `http-types`' types that are part of `wiremock`'s public API. -pub use http_types::headers::{HeaderName, HeaderValue, HeaderValues}; -pub use http_types::{Method, Url}; +//! Convenient re-exports of http types that are part of `wiremock`'s public API. +pub use http::{HeaderMap, HeaderName, HeaderValue, Method}; +pub use url::Url; diff --git a/src/matchers.rs b/src/matchers.rs index f18205b..00bbfbe 100644 --- a/src/matchers.rs +++ b/src/matchers.rs @@ -10,15 +10,14 @@ use crate::{Match, Request}; use assert_json_diff::{assert_json_matches_no_panic, CompareMode}; use base64::prelude::{Engine as _, BASE64_STANDARD}; -use http_types::headers::{HeaderName, HeaderValue, HeaderValues}; -use http_types::{Method, Url}; +use http::{HeaderName, HeaderValue, Method}; use log::debug; use regex::Regex; use serde::Serialize; use serde_json::Value; use std::convert::TryInto; -use std::ops::Deref; use std::str; +use url::Url; /// Implement the `Match` trait for all closures, out of the box, /// if their signature is compatible. @@ -342,7 +341,7 @@ impl Match for PathRegexMatcher { /// assert_eq!(status, 200); /// } /// ``` -pub struct HeaderExactMatcher(HeaderName, HeaderValues); +pub struct HeaderExactMatcher(HeaderName, Vec); /// Shorthand for [`HeaderExactMatcher::new`]. pub fn header(key: K, value: V) -> HeaderExactMatcher @@ -352,7 +351,7 @@ where V: TryInto, >::Error: std::fmt::Debug, { - HeaderExactMatcher::new(key, value.try_into().map(HeaderValues::from).unwrap()) + HeaderExactMatcher::new(key, vec![value]) } /// Shorthand for [`HeaderExactMatcher::new`] supporting multi valued headers. @@ -363,38 +362,44 @@ where V: TryInto, >::Error: std::fmt::Debug, { - let values = values - .into_iter() - .filter_map(|v| v.try_into().ok()) - .collect::(); HeaderExactMatcher::new(key, values) } impl HeaderExactMatcher { - pub fn new(key: K, value: V) -> Self + pub fn new(key: K, values: Vec) -> Self where K: TryInto, >::Error: std::fmt::Debug, - V: TryInto, - >::Error: std::fmt::Debug, + V: TryInto, + >::Error: std::fmt::Debug, { let key = key.try_into().expect("Failed to convert to header name."); - let value = value - .try_into() - .expect("Failed to convert to header value."); - Self(key, value) + let values = values + .into_iter() + .map(|value| { + value + .try_into() + .expect("Failed to convert to header value.") + }) + .collect(); + Self(key, values) } } impl Match for HeaderExactMatcher { fn matches(&self, request: &Request) -> bool { - match request.headers.get(&self.0) { - None => false, - Some(values) => { - let headers: Vec<&str> = self.1.iter().map(HeaderValue::as_str).collect(); - values.eq(headers.as_slice()) - } - } + let values = request + .headers + .get_all(&self.0) + .iter() + .filter_map(|v| v.to_str().ok()) + .flat_map(|v| { + v.split(',') + .map(str::trim) + .filter_map(|v| HeaderValue::from_str(v).ok()) + }) + .collect::>(); + values == self.1 // order matters } } @@ -513,12 +518,16 @@ impl HeaderRegexMatcher { impl Match for HeaderRegexMatcher { fn matches(&self, request: &Request) -> bool { - match request.headers.get(&self.0) { - None => false, - Some(values) => { - let has_values = values.iter().next().is_some(); - has_values && values.iter().all(|v| self.1.is_match(v.as_str())) - } + let mut it = request + .headers + .get_all(&self.0) + .iter() + .filter_map(|v| v.to_str().ok()) + .peekable(); + if it.peek().is_some() { + it.all(|v| self.1.is_match(v)) + } else { + false } } } @@ -994,7 +1003,6 @@ where /// use wiremock::{MockServer, Mock, ResponseTemplate}; /// use wiremock::matchers::basic_auth; /// use serde::{Deserialize, Serialize}; -/// use http_types::auth::BasicAuth; /// use std::convert::TryInto; /// /// #[async_std::main] @@ -1008,10 +1016,9 @@ where /// .mount(&mock_server) /// .await; /// -/// let auth = BasicAuth::new("username", "password"); /// let client: surf::Client = surf::Config::new() /// .set_base_url(surf::Url::parse(&mock_server.uri()).unwrap()) -/// .add_header(auth.name(), auth.value()).unwrap() +/// .add_header("Authorization", "Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap() /// .try_into().unwrap(); /// /// // Act @@ -1040,7 +1047,7 @@ impl BasicAuthMatcher { pub fn from_token(token: impl AsRef) -> Self { Self(header( "Authorization", - format!("Basic {}", token.as_ref()).deref(), + &*format!("Basic {}", token.as_ref()), )) } } @@ -1069,7 +1076,6 @@ impl Match for BasicAuthMatcher { /// use wiremock::{MockServer, Mock, ResponseTemplate}; /// use wiremock::matchers::bearer_token; /// use serde::{Deserialize, Serialize}; -/// use http_types::auth::BasicAuth; /// /// #[async_std::main] /// async fn main() { @@ -1098,7 +1104,7 @@ impl BearerTokenMatcher { pub fn from_token(token: impl AsRef) -> Self { Self(header( "Authorization", - format!("Bearer {}", token.as_ref()).deref(), + &*format!("Bearer {}", token.as_ref()), )) } } diff --git a/src/mock.rs b/src/mock.rs index 9b50edb..7a188aa 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -15,13 +15,13 @@ use std::ops::{ /// use std::convert::TryInto; /// /// // Check that a header with the specified name exists and its value has an odd length. -/// pub struct OddHeaderMatcher(http_types::headers::HeaderName); +/// pub struct OddHeaderMatcher(http::HeaderName); /// /// impl Match for OddHeaderMatcher { /// fn matches(&self, request: &Request) -> bool { /// match request.headers.get(&self.0) { /// // We are ignoring multi-valued headers for simplicity -/// Some(values) => values[0].as_str().len() % 2 == 1, +/// Some(value) => value.to_str().unwrap_or_default().len() % 2 == 1, /// None => false /// } /// } @@ -69,11 +69,11 @@ use std::ops::{ /// // Arrange /// let mock_server = MockServer::start().await; /// -/// let header_name: http_types::headers::HeaderName = "custom".try_into().unwrap(); +/// let header_name = http::HeaderName::from_static("custom"); /// // Check that a header with the specified name exists and its value has an odd length. /// let matcher = move |request: &Request| { /// match request.headers.get(&header_name) { -/// Some(values) => values[0].as_str().len() % 2 == 1, +/// Some(value) => value.to_str().unwrap_or_default().len() % 2 == 1, /// None => false /// } /// }; diff --git a/src/mock_server/bare_server.rs b/src/mock_server/bare_server.rs index fd31308..7ed6cbb 100644 --- a/src/mock_server/bare_server.rs +++ b/src/mock_server/bare_server.rs @@ -3,6 +3,8 @@ use crate::mock_set::MockId; use crate::mock_set::MountedMockSet; use crate::request::BodyPrintLimit; use crate::{mock::Mock, verification::VerificationOutcome, Request}; +use http_body_util::Full; +use hyper::body::Bytes; use std::fmt::{Debug, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::pin::pin; @@ -10,7 +12,6 @@ use std::sync::atomic::AtomicBool; use std::sync::Arc; use tokio::sync::Notify; use tokio::sync::RwLock; -use tokio::task::LocalSet; /// An HTTP web-server running in the background to behave as one of your dependencies using `Mock`s /// for testing purposes. @@ -22,7 +23,7 @@ pub(crate) struct BareMockServer { state: Arc>, server_address: SocketAddr, // When `_shutdown_trigger` gets dropped the listening server terminates gracefully. - _shutdown_trigger: tokio::sync::oneshot::Sender<()>, + _shutdown_trigger: tokio::sync::watch::Sender<()>, } /// The elements of [`BareMockServer`] that are affected by each incoming request. @@ -38,7 +39,7 @@ impl MockServerState { pub(super) async fn handle_request( &mut self, mut request: Request, - ) -> (http_types::Response, Option) { + ) -> (hyper::Response>, Option) { request.body_print_limit = self.body_print_limit; // If request recording is enabled, record the incoming request // by adding it to the `received_requests` stack @@ -57,7 +58,7 @@ impl BareMockServer { request_recording: RequestRecording, body_print_limit: BodyPrintLimit, ) -> Self { - let (shutdown_trigger, shutdown_receiver) = tokio::sync::oneshot::channel(); + let (shutdown_trigger, shutdown_receiver) = tokio::sync::watch::channel(()); let received_requests = match request_recording { RequestRecording::Enabled => Some(Vec::new()), RequestRecording::Disabled => None, @@ -80,7 +81,7 @@ impl BareMockServer { .build() .expect("Cannot build local tokio runtime"); - LocalSet::new().block_on(&runtime, server_future) + runtime.block_on(server_future); }); for _ in 0..40 { if TcpStream::connect_timeout(&server_address, std::time::Duration::from_millis(25)) @@ -88,7 +89,7 @@ impl BareMockServer { { break; } - futures_timer::Delay::new(std::time::Duration::from_millis(25)).await; + tokio::time::sleep(std::time::Duration::from_millis(25)).await; } Self { @@ -162,7 +163,7 @@ impl BareMockServer { /// If request recording was disabled, it returns `None`. pub(crate) async fn received_requests(&self) -> Option> { let state = self.state.read().await; - state.received_requests.to_owned() + state.received_requests.clone() } } @@ -270,7 +271,7 @@ impl MockGuard { } // await event - notification.await + notification.await; } } @@ -292,15 +293,13 @@ impl Drop for MockGuard { if received_requests.is_empty() { "The server did not receive any request.".into() } else { - let requests = received_requests.iter().enumerate().fold( - String::new(), - |mut r, (index, request)| { - write!(r, "- Request #{idx}\n\t{request}", idx = index + 1,) - .unwrap(); - r + received_requests.iter().enumerate().fold( + "Received requests:\n".to_string(), + |mut message, (index, request)| { + _ = write!(message, "- Request #{}\n\t{}", index + 1, request); + message }, - ); - format!("Received requests:\n{requests}") + ) } } else { "Enable request recording on the mock server to get the list of incoming requests as part of the panic message.".into() @@ -320,6 +319,6 @@ impl Drop for MockGuard { state.mock_set.deactivate(*mock_id); } }; - futures::executor::block_on(future) + futures::executor::block_on(future); } } diff --git a/src/mock_server/exposed_server.rs b/src/mock_server/exposed_server.rs index 192a4aa..afce12d 100644 --- a/src/mock_server/exposed_server.rs +++ b/src/mock_server/exposed_server.rs @@ -50,7 +50,7 @@ impl Deref for InnerServer { fn deref(&self) -> &Self::Target { match self { InnerServer::Bare(b) => b, - InnerServer::Pooled(p) => p.deref(), + InnerServer::Pooled(p) => p, } } } @@ -160,7 +160,7 @@ impl MockServer { /// /// [`mount`]: Mock::mount pub async fn register(&self, mock: Mock) { - self.0.register(mock).await + self.0.register(mock).await; } /// Register a **scoped** [`Mock`] on an instance of `MockServer`. @@ -336,22 +336,21 @@ impl MockServer { if received_requests.is_empty() { "The server did not receive any request.".into() } else { - let requests = received_requests.into_iter().enumerate().fold( - String::new(), - |mut r, (index, request)| { - write!(r, "- Request #{idx}\n\t{request}", idx = index + 1).unwrap(); // infallible - r + received_requests.iter().enumerate().fold( + "Received requests:\n".to_string(), + |mut message, (index, request)| { + _ = write!(message, "- Request #{}\n\t{}", index + 1, request); + message }, - ); - format!("Received requests:\n{requests}",) + ) } } else { "Enable request recording on the mock server to get the list of incoming requests as part of the panic message.".into() }; - let verifications_errors = - failed_verifications.iter().fold(String::new(), |mut e, m| { - writeln!(e, "- {}", m.error_message()).unwrap(); // infallible - e + let verifications_errors: String = + failed_verifications.iter().fold(String::new(), |mut s, m| { + _ = writeln!(s, "- {}", m.error_message()); + s }); let error_message = format!( "Verifications failed:\n{verifications_errors}\n{received_requests_message}", @@ -424,7 +423,7 @@ impl MockServer { /// /// ```rust /// use wiremock::MockServer; - /// use http_types::Method; + /// use http::Method; /// /// #[async_std::main] /// async fn main() { @@ -439,7 +438,7 @@ impl MockServer { /// assert_eq!(received_requests.len(), 1); /// /// let received_request = &received_requests[0]; - /// assert_eq!(received_request.method, Method::Get); + /// assert_eq!(received_request.method, Method::GET); /// assert_eq!(received_request.url.path(), "/"); /// assert!(received_request.body.is_empty()); /// } @@ -484,7 +483,7 @@ impl MockServer { impl Drop for MockServer { // Clean up when the `MockServer` instance goes out of scope. fn drop(&mut self) { - futures::executor::block_on(self.verify()) + futures::executor::block_on(self.verify()); // The sender half of the channel, `shutdown_trigger`, gets dropped here // Triggering the graceful shutdown of the server itself. } diff --git a/src/mock_server/hyper.rs b/src/mock_server/hyper.rs index 8638d6a..cc31295 100644 --- a/src/mock_server/hyper.rs +++ b/src/mock_server/hyper.rs @@ -1,107 +1,76 @@ use crate::mock_server::bare_server::MockServerState; -use hyper::http; -use hyper::service::{make_service_fn, service_fn}; -use std::net::TcpListener; +use hyper::service::service_fn; +use hyper_util::rt::TokioIo; use std::sync::Arc; +use tokio::net::TcpListener; use tokio::sync::RwLock; -type DynError = Box; - /// The actual HTTP server responding to incoming requests according to the specified mocks. pub(super) async fn run_server( - listener: TcpListener, + listener: std::net::TcpListener, server_state: Arc>, - shutdown_signal: tokio::sync::oneshot::Receiver<()>, + mut shutdown_signal: tokio::sync::watch::Receiver<()>, ) { - let request_handler = make_service_fn(move |_| { + listener + .set_nonblocking(true) + .expect("Cannot set non-blocking mode on TcpListener"); + let listener = TcpListener::from_std(listener).expect("Cannot upgrade TcpListener"); + + let request_handler = move |request| { let server_state = server_state.clone(); async move { - Ok::<_, DynError>(service_fn(move |request: hyper::Request| { - let server_state = server_state.clone(); - async move { - let wiremock_request = crate::Request::from_hyper(request).await; - let (response, delay) = server_state - .write() - .await - .handle_request(wiremock_request) - .await; + let wiremock_request = crate::Request::from_hyper(request).await; + let (response, delay) = server_state + .write() + .await + .handle_request(wiremock_request) + .await; - // We do not wait for the delay within the handler otherwise we would be - // holding on to the write-side of the `RwLock` on `mock_set`. - // Holding on the lock while waiting prevents us from handling other requests until - // we have waited the whole duration specified in the delay. - // In particular, we cannot perform even perform read-only operation - - // e.g. check that mock assumptions have been verified. - // Using long delays in tests without handling the delay as we are doing here - // caused tests to hang (see https://github.com/seanmonstar/reqwest/issues/1147) - if let Some(delay) = delay { - delay.await; - } + // We do not wait for the delay within the handler otherwise we would be + // holding on to the write-side of the `RwLock` on `mock_set`. + // Holding on the lock while waiting prevents us from handling other requests until + // we have waited the whole duration specified in the delay. + // In particular, we cannot perform even perform read-only operation - + // e.g. check that mock assumptions have been verified. + // Using long delays in tests without handling the delay as we are doing here + // caused tests to hang (see https://github.com/seanmonstar/reqwest/issues/1147) + if let Some(delay) = delay { + delay.await; + } - Ok::<_, DynError>(http_types_response_to_hyper_response(response).await) - } - })) + Ok::<_, &'static str>(response) } - }); + }; - let server = hyper::Server::from_tcp(listener) - .unwrap() - .executor(LocalExec) - .serve(request_handler) - .with_graceful_shutdown(async { - // This futures resolves when either: - // - the sender half of the channel gets dropped (i.e. MockServer is dropped) - // - the sender is used, therefore sending a poison pill willingly as a shutdown signal - let _ = shutdown_signal.await; - }); - - if let Err(e) = server.await { - panic!("Mock server failed: {}", e); - } -} - -// An executor that can spawn !Send futures. -#[derive(Clone, Copy, Debug)] -struct LocalExec; - -impl hyper::rt::Executor for LocalExec -where - F: std::future::Future + 'static, // not requiring `Send` -{ - fn execute(&self, fut: F) { - // This will spawn into the currently running `LocalSet`. - tokio::task::spawn_local(fut); - } -} - -async fn http_types_response_to_hyper_response( - mut response: http_types::Response, -) -> hyper::Response { - let version = response.version().map(|v| v.into()).unwrap_or_default(); - let mut builder = http::response::Builder::new() - .status(response.status() as u16) - .version(version); - - headers_to_hyperium_headers(response.as_mut(), builder.headers_mut().unwrap()); - - let body_bytes = response.take_body().into_bytes().await.unwrap(); - let body = hyper::Body::from(body_bytes); - - builder.body(body).unwrap() -} + loop { + let (stream, _) = tokio::select! { biased; + accepted = listener.accept() => { + match accepted { + Ok(accepted) => accepted, + Err(_) => break, + } + }, + _ = shutdown_signal.changed() => { + log::info!("Mock server shutting down"); + break; + } + }; + let io = TokioIo::new(stream); -fn headers_to_hyperium_headers( - headers: &mut http_types::Headers, - hyperium_headers: &mut http::HeaderMap, -) { - for (name, values) in headers { - let name = format!("{}", name).into_bytes(); - let name = http::header::HeaderName::from_bytes(&name).unwrap(); + let request_handler = request_handler.clone(); + let mut shutdown_signal = shutdown_signal.clone(); + tokio::task::spawn(async move { + let conn = hyper::server::conn::http1::Builder::new() + .serve_connection(io, service_fn(request_handler)) + .with_upgrades(); + tokio::pin!(conn); - for value in values.iter() { - let value = format!("{}", value).into_bytes(); - let value = http::header::HeaderValue::from_bytes(&value).unwrap(); - hyperium_headers.append(&name, value); - } + loop { + tokio::select! { + _ = conn.as_mut() => break, + _ = shutdown_signal.changed() => conn.as_mut().graceful_shutdown(), + } + } + }); } } diff --git a/src/mock_set.rs b/src/mock_set.rs index 996d6f0..9786db9 100644 --- a/src/mock_set.rs +++ b/src/mock_set.rs @@ -3,14 +3,15 @@ use crate::{ verification::{VerificationOutcome, VerificationReport}, }; use crate::{Mock, Request, ResponseTemplate}; -use futures_timer::Delay; -use http_types::{Response, StatusCode}; +use http_body_util::Full; +use hyper::body::Bytes; use log::debug; use std::{ ops::{Index, IndexMut}, sync::{atomic::AtomicBool, Arc}, }; use tokio::sync::Notify; +use tokio::time::{sleep, Sleep}; /// The collection of mocks used by a `MockServer` instance to match against /// incoming requests. @@ -41,7 +42,7 @@ pub(crate) struct MockId { } impl MountedMockSet { - /// Create a new instance of MockSet. + /// Create a new instance of `MockSet`. pub(crate) fn new() -> MountedMockSet { MountedMockSet { mocks: vec![], @@ -49,7 +50,10 @@ impl MountedMockSet { } } - pub(crate) async fn handle_request(&mut self, request: Request) -> (Response, Option) { + pub(crate) async fn handle_request( + &mut self, + request: Request, + ) -> (hyper::Response>, Option) { debug!("Handling request."); let mut response_template: Option = None; self.mocks.sort_by_key(|(m, _)| m.specification.priority); @@ -63,11 +67,17 @@ impl MountedMockSet { } } if let Some(response_template) = response_template { - let delay = response_template.delay().map(|d| Delay::new(d.to_owned())); + let delay = response_template.delay().map(sleep); (response_template.generate_response(), delay) } else { debug!("Got unexpected request:\n{}", request); - (Response::new(StatusCode::NotFound), None) + ( + hyper::Response::builder() + .status(hyper::StatusCode::NOT_FOUND) + .body(Full::default()) + .unwrap(), + None, + ) } } diff --git a/src/request.rs b/src/request.rs index b3508a5..d1f8faa 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,11 +1,9 @@ -use std::iter::FromIterator; -use std::str::FromStr; -use std::{collections::HashMap, fmt}; +use std::fmt; -use futures::AsyncReadExt; -use http_types::convert::DeserializeOwned; -use http_types::headers::{HeaderName, HeaderValue, HeaderValues}; -use http_types::{Method, Url}; +use http::{HeaderMap, Method}; +use http_body_util::BodyExt; +use serde::de::DeserializeOwned; +use url::Url; pub const BODY_PRINT_LIMIT: usize = 10_000; @@ -41,7 +39,7 @@ pub enum BodyPrintLimit { pub struct Request { pub url: Url, pub method: Method, - pub headers: HashMap, + pub headers: HeaderMap, pub body: Vec, pub body_print_limit: BodyPrintLimit, } @@ -49,10 +47,12 @@ pub struct Request { impl fmt::Display for Request { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { writeln!(f, "{} {}", self.method, self.url)?; - for (name, values) in &self.headers { - let values = values + for name in self.headers.keys() { + let values = self + .headers + .get_all(name) .iter() - .map(|value| format!("{}", value)) + .map(|value| String::from_utf8_lossy(value.as_bytes())) .collect::>(); let values = values.join(","); writeln!(f, "{}: {}", name, values)?; @@ -107,35 +107,8 @@ impl Request { serde_json::from_slice(&self.body) } - pub async fn from(mut request: http_types::Request) -> Request { - let method = request.method(); - let url = request.url().to_owned(); - - let mut headers = HashMap::new(); - for (header_name, header_values) in &request { - headers.insert(header_name.to_owned(), header_values.to_owned()); - } - - let mut body: Vec = vec![]; - request - .take_body() - .into_reader() - .read_to_end(&mut body) - .await - .expect("Failed to read body"); - - Self { - url, - method, - headers, - body, - body_print_limit: BodyPrintLimit::Limited(BODY_PRINT_LIMIT), - } - } - - pub(crate) async fn from_hyper(request: hyper::Request) -> Request { + pub(crate) async fn from_hyper(request: hyper::Request) -> Request { let (parts, body) = request.into_parts(); - let method = parts.method.into(); let url = match parts.uri.authority() { Some(_) => parts.uri.to_string(), None => format!("http://localhost{}", parts.uri), @@ -143,37 +116,17 @@ impl Request { .parse() .unwrap(); - let mut headers = HashMap::new(); - for name in parts.headers.keys() { - let name = name.as_str().as_bytes().to_owned(); - let name = HeaderName::from_bytes(name).unwrap(); - let values = parts.headers.get_all(name.as_str()); - for value in values { - let value = value.as_bytes().to_owned(); - let value = HeaderValue::from_bytes(value).unwrap(); - let value_parts = value.as_str().split(','); - let value_parts = value_parts - .map(|it| it.trim()) - .filter_map(|it| HeaderValue::from_str(it).ok()); - headers - .entry(name.clone()) - .and_modify(|values: &mut HeaderValues| { - values.append(&mut HeaderValues::from_iter(value_parts.clone())) - }) - .or_insert_with(|| value_parts.collect()); - } - } - - let body = hyper::body::to_bytes(body) + let body = body + .collect() .await .expect("Failed to read request body.") - .to_vec(); + .to_bytes(); Self { url, - method, - headers, - body, + method: parts.method, + headers: parts.headers, + body: body.to_vec(), body_print_limit: BodyPrintLimit::Limited(BODY_PRINT_LIMIT), } } diff --git a/src/respond.rs b/src/respond.rs index c5df09f..fe4955e 100644 --- a/src/respond.rs +++ b/src/respond.rs @@ -75,7 +75,7 @@ use crate::{Request, ResponseTemplate}; /// `Respond` that propagates back a request header in the response: /// /// ```rust -/// use http_types::headers::HeaderName; +/// use http::HeaderName; /// use wiremock::{Match, MockServer, Mock, Request, ResponseTemplate, Respond}; /// use wiremock::matchers::path; /// use std::convert::TryInto; @@ -87,12 +87,12 @@ use crate::{Request, ResponseTemplate}; /// /// impl Respond for CorrelationIdResponder { /// fn respond(&self, request: &Request) -> ResponseTemplate { +/// const HEADER: HeaderName = HeaderName::from_static("x-correlation-id"); /// let mut response_template = self.0.clone(); -/// let header_name = HeaderName::from_str("X-Correlation-Id").unwrap(); -/// if let Some(correlation_id) = request.headers.get(&header_name) { +/// if let Some(correlation_id) = request.headers.get(&HEADER) { /// response_template = response_template.insert_header( -/// header_name, -/// correlation_id.last().to_owned() +/// HEADER, +/// correlation_id.to_owned() /// ); /// } /// response_template diff --git a/src/response_template.rs b/src/response_template.rs index b914dea..a8c5ce7 100644 --- a/src/response_template.rs +++ b/src/response_template.rs @@ -1,9 +1,8 @@ -use http_types::headers::{HeaderName, HeaderValue}; -use http_types::{Response, StatusCode}; +use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode}; +use http_body_util::Full; +use hyper::body::Bytes; use serde::Serialize; -use std::collections::HashMap; use std::convert::TryInto; -use std::str::FromStr; use std::time::Duration; /// The blueprint for the response returned by a [`MockServer`] when a [`Mock`] matches on an incoming request. @@ -12,9 +11,9 @@ use std::time::Duration; /// [`MockServer`]: crate::MockServer #[derive(Clone, Debug)] pub struct ResponseTemplate { - mime: Option, + mime: String, status_code: StatusCode, - headers: HashMap>, + headers: HeaderMap, body: Option>, delay: Option, } @@ -37,8 +36,8 @@ impl ResponseTemplate { let status_code = s.try_into().expect("Failed to convert into status code."); Self { status_code, - headers: HashMap::new(), - mime: None, + headers: HeaderMap::new(), + mime: String::new(), body: None, delay: None, } @@ -61,14 +60,7 @@ impl ResponseTemplate { let value = value .try_into() .expect("Failed to convert into header value."); - match self.headers.get_mut(&key) { - Some(headers) => { - headers.push(value); - } - None => { - self.headers.insert(key, vec![value]); - } - } + self.headers.append(key, value); self } @@ -118,7 +110,7 @@ impl ResponseTemplate { let value = value .try_into() .expect("Failed to convert into header value."); - self.headers.insert(key, vec![value]); + self.headers.insert(key, value); self } @@ -145,10 +137,7 @@ impl ResponseTemplate { let body = serde_json::to_vec(&body).expect("Failed to convert into body."); self.body = Some(body); - self.mime = Some( - http_types::Mime::from_str("application/json") - .expect("Failed to convert into Mime header"), - ); + self.mime = "application/json".to_string(); self } @@ -163,9 +152,7 @@ impl ResponseTemplate { let body = body.try_into().expect("Failed to convert into body."); self.body = Some(body.into_bytes()); - self.mime = Some( - http_types::Mime::from_str("text/plain").expect("Failed to convert into Mime header"), - ); + self.mime = "text/plain".to_string(); self } @@ -220,8 +207,7 @@ impl ResponseTemplate { { let body = body.try_into().expect("Failed to convert into body."); self.body = Some(body); - self.mime = - Some(http_types::Mime::from_str(mime).expect("Failed to convert into Mime header")); + self.mime = mime.to_string(); self } @@ -234,7 +220,6 @@ impl ResponseTemplate { /// /// ### Example: /// ```rust - /// use isahc::config::Configurable; /// use wiremock::{MockServer, Mock, ResponseTemplate}; /// use wiremock::matchers::method; /// use std::time::Duration; @@ -272,25 +257,18 @@ impl ResponseTemplate { } /// Generate a response from the template. - pub(crate) fn generate_response(&self) -> Response { - let mut response = Response::new(self.status_code); - - // Add headers - for (header_name, header_values) in &self.headers { - response.insert_header(header_name.clone(), header_values.as_slice()); - } - - // Add body, if specified - if let Some(body) = &self.body { - response.set_body(body.clone()); - } + pub(crate) fn generate_response(&self) -> Response> { + let mut response = Response::builder().status(self.status_code); + let mut headers = self.headers.clone(); // Set content-type, if needed - if let Some(mime) = &self.mime { - response.set_content_type(mime.to_owned()); + if !self.mime.is_empty() { + headers.insert(http::header::CONTENT_TYPE, self.mime.parse().unwrap()); } + *response.headers_mut().unwrap() = headers; - response + let body = self.body.clone().unwrap_or_default(); + response.body(body.into()).unwrap() } /// Retrieve the response delay. diff --git a/tests/mocks.rs b/tests/mocks.rs index ce3dab4..e2ef13d 100644 --- a/tests/mocks.rs +++ b/tests/mocks.rs @@ -1,9 +1,9 @@ use futures::FutureExt; -use http_types::StatusCode; use serde::Serialize; use serde_json::json; use std::net::TcpStream; use std::time::Duration; +use surf::StatusCode; use wiremock::matchers::{body_json, body_partial_json, method, path, PathExactMatcher}; use wiremock::{Mock, MockServer, ResponseTemplate}; diff --git a/tests/request_header_matching.rs b/tests/request_header_matching.rs index 0c90db4..2bc5ee4 100644 --- a/tests/request_header_matching.rs +++ b/tests/request_header_matching.rs @@ -1,4 +1,3 @@ -use hyper::HeaderMap; use wiremock::matchers::{basic_auth, bearer_token, header, header_regex, headers, method}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -88,12 +87,11 @@ async fn should_match_multi_request_header_x() { mock_server.register(mock).await; // Act - let mut header_map = HeaderMap::new(); - header_map.append("cache-control", "no-cache".parse().unwrap()); - header_map.append("cache-control", "no-store".parse().unwrap()); let should_match = reqwest::Client::new() .get(mock_server.uri()) - .headers(header_map) + // TODO: use a dedicated headers when upgrade reqwest v0.12 + .header("cache-control", "no-cache") + .header("cache-control", "no-store") .send() .await .unwrap();