diff --git a/sdk/core/src/bytes_response.rs b/sdk/core/src/bytes_response.rs deleted file mode 100644 index f71690eea6..0000000000 --- a/sdk/core/src/bytes_response.rs +++ /dev/null @@ -1,98 +0,0 @@ -use bytes::Bytes; -use http::{HeaderMap, StatusCode}; - -#[derive(Debug, Clone, PartialEq)] -pub(crate) struct BytesResponse { - status: StatusCode, - headers: HeaderMap, - body: Bytes, -} - -impl BytesResponse { - pub(crate) fn deconstruct(self) -> (StatusCode, HeaderMap, Bytes) { - (self.status, self.headers, self.body) - } -} - -#[cfg(feature = "mock_transport_framework")] -mod mock_transport { - use super::*; - use crate::{collect_pinned_stream, BytesStream, Response}; - use http::{header, HeaderMap, StatusCode}; - use serde::{Deserialize, Serialize}; - use std::collections::BTreeMap; - - impl BytesResponse { - pub(crate) fn new(status: StatusCode, headers: HeaderMap, body: Bytes) -> Self { - Self { - status, - headers, - body, - } - } - - pub(crate) async fn duplicate( - response: crate::Response, - ) -> Result<(crate::Response, Self), crate::StreamError> { - let (status_code, header_map, pinned_stream) = response.deconstruct(); - let response_bytes = collect_pinned_stream(pinned_stream).await?; - - let response = Response::new( - status_code, - header_map.clone(), - Box::pin(BytesStream::new(response_bytes.clone())), - ); - let bytes_response = BytesResponse::new(status_code, header_map, response_bytes); - - Ok((response, bytes_response)) - } - } - - impl<'de> Deserialize<'de> for BytesResponse { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - use serde::de::Error; - let r = SerializedBytesResponse::deserialize(deserializer)?; - let mut headers = HeaderMap::new(); - for (n, v) in r.headers.iter() { - let name = - header::HeaderName::from_lowercase(n.as_bytes()).map_err(Error::custom)?; - let value = header::HeaderValue::from_str(&v).map_err(Error::custom)?; - headers.insert(name, value); - } - let body = Bytes::from(base64::decode(r.body).map_err(Error::custom)?); - let status = StatusCode::from_u16(r.status).map_err(Error::custom)?; - - Ok(Self::new(status, headers, body)) - } - } - - impl Serialize for BytesResponse { - fn serialize(&self, serializer: S) -> Result - where - S: serde::ser::Serializer, - { - let mut headers = BTreeMap::new(); - for (h, v) in self.headers.iter() { - headers.insert(h.as_str().into(), v.to_str().unwrap().into()); - } - let status = self.status.as_u16(); - let body = base64::encode(&self.body as &[u8]); - let s = SerializedBytesResponse { - status, - headers, - body, - }; - s.serialize(serializer) - } - } - - #[derive(Serialize, Deserialize)] - pub(crate) struct SerializedBytesResponse { - status: u16, - headers: BTreeMap, - body: String, - } -} diff --git a/sdk/core/src/errors.rs b/sdk/core/src/errors.rs index 45e4630a7e..cbdb6d3aee 100644 --- a/sdk/core/src/errors.rs +++ b/sdk/core/src/errors.rs @@ -217,32 +217,6 @@ pub enum TraversingError { ParsingError(#[from] ParsingError), } -/// An error relating to the mock transport framework. -#[cfg(feature = "mock_transport_framework")] -#[derive(Debug, thiserror::Error)] -pub enum MockFrameworkError { - #[error("the mock testing framework has not been initialized")] - UninitializedTransaction, - #[error("{0}: {1}")] - IOError(String, std::io::Error), - #[error("{0}")] - TransactionStorageError(String), - #[error("{0}")] - MissingTransaction(String), - #[error("mismatched request uri. Actual '{0}', Expected: '{1}'")] - MismatchedRequestUri(String, String), - #[error("received request have header {0} but it was not present in the read request")] - MissingRequestHeader(String), - #[error("different number of headers in request. Actual: {0}, Expected: {1}")] - MismatchedRequestHeadersCount(usize, usize), - #[error("request header {0} value is different. Actual: {1}, Expected: {2}")] - MismatchedRequestHeader(String, String, String), - #[error("mismatched HTTP request method. Actual: {0}, Expected: {1}")] - MismatchedRequestHTTPMethod(http::Method, http::Method), - #[error("mismatched request body. Actual: {0:?}, Expected: {1:?}")] - MismatchedRequestBody(Vec, Vec), -} - /// Extract the headers and body from a `hyper` HTTP response. #[cfg(feature = "enable_hyper")] #[inline] diff --git a/sdk/core/src/lib.rs b/sdk/core/src/lib.rs index 4ec75cdc1e..f93a29b7ee 100644 --- a/sdk/core/src/lib.rs +++ b/sdk/core/src/lib.rs @@ -14,7 +14,6 @@ extern crate serde_derive; #[macro_use] mod macros; -mod bytes_response; mod bytes_stream; mod constants; mod context; @@ -33,7 +32,7 @@ mod sleep; pub mod auth; pub mod headers; #[cfg(feature = "mock_transport_framework")] -mod mock_transaction; +pub mod mock; pub mod parsing; pub mod prelude; pub mod util; @@ -48,8 +47,6 @@ pub use errors::*; #[doc(inline)] pub use headers::AddAsHeader; pub use http_client::{new_http_client, to_json, HttpClient}; -#[cfg(feature = "mock_transport_framework")] -pub use mock_transaction::constants::*; pub use models::*; pub use options::*; pub use pipeline::Pipeline; diff --git a/sdk/core/src/mock/mock_request.rs b/sdk/core/src/mock/mock_request.rs new file mode 100644 index 0000000000..6b924e6cad --- /dev/null +++ b/sdk/core/src/mock/mock_request.rs @@ -0,0 +1,145 @@ +use crate::{Body, Request}; +use http::{HeaderMap, Method, Uri}; +use serde::de::Visitor; +use serde::ser::{Serialize, SerializeStruct, Serializer}; +use serde::{Deserialize, Deserializer}; +use std::collections::HashMap; +use std::str::FromStr; + +const FIELDS: &[&str] = &["uri", "method", "headers", "body"]; + +impl Request { + fn new(uri: Uri, method: Method, headers: HeaderMap, body: Body) -> Self { + Self { + uri, + method, + headers, + body, + } + } +} + +impl<'de> Deserialize<'de> for Request { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_struct("Request", FIELDS, RequestVisitor) + } +} + +struct RequestVisitor; + +impl<'de> Visitor<'de> for RequestVisitor { + type Value = Request; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct Request") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let uri: (&str, &str) = match map.next_entry()? { + Some((a, b)) => (a, b), + None => return Err(serde::de::Error::custom("missing uri")), + }; + + if uri.0 != FIELDS[0] { + return Err(serde::de::Error::custom(format!( + "unexpected field {}, expected {}", + uri.0, FIELDS[0] + ))); + } + + let method: (&str, &str) = match map.next_entry()? { + Some((a, b)) => (a, b), + None => return Err(serde::de::Error::custom("missing method")), + }; + + if method.0 != FIELDS[1] { + return Err(serde::de::Error::custom(format!( + "unexpected field {}, expected {}", + method.0, FIELDS[1] + ))); + } + + let headers: (&str, HashMap<&str, String>) = match map.next_entry()? { + Some((a, b)) => (a, b), + None => return Err(serde::de::Error::custom("missing header map")), + }; + if headers.0 != FIELDS[2] { + return Err(serde::de::Error::custom(format!( + "unexpected field {}, expected {}", + headers.0, FIELDS[2] + ))); + } + + let body: (&str, String) = match map.next_entry()? { + Some((a, b)) => (a, b), + None => return Err(serde::de::Error::custom("missing body")), + }; + if body.0 != FIELDS[3] { + return Err(serde::de::Error::custom(format!( + "unexpected field {}, expected {}", + body.0, FIELDS[3] + ))); + } + + let body = base64::decode(&body.1).map_err(serde::de::Error::custom)?; + + let mut hm = HeaderMap::new(); + for (k, v) in headers.1.into_iter() { + hm.append( + http::header::HeaderName::from_lowercase(k.as_bytes()) + .map_err(serde::de::Error::custom)?, + http::HeaderValue::from_str(&v).map_err(serde::de::Error::custom)?, + ); + } + + Ok(Self::Value::new( + Uri::from_str(uri.1).expect("expected a valid uri"), + Method::from_str(method.1).expect("expected a valid HTTP method"), + hm, + bytes::Bytes::from(body).into(), + )) + } +} + +impl Serialize for Request { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut hm = std::collections::BTreeMap::new(); + for (h, v) in self.headers().iter() { + if h.as_str().to_lowercase() == "authorization" { + hm.insert(h.to_string(), "<>"); + } else { + hm.insert(h.to_string(), v.to_str().unwrap()); + } + } + + let mut state = serializer.serialize_struct("Request", 4)?; + state.serialize_field( + FIELDS[0], + &self + .uri + .path_and_query() + .map(|p| p.to_string()) + .unwrap_or_else(String::new), + )?; + state.serialize_field(FIELDS[1], &self.method.to_string())?; + state.serialize_field(FIELDS[2], &hm)?; + state.serialize_field( + FIELDS[3], + &match &self.body { + Body::Bytes(bytes) => base64::encode(bytes as &[u8]), + Body::SeekableStream(_) => unimplemented!(), + }, + )?; + + state.end() + } +} diff --git a/sdk/core/src/mock/mock_response.rs b/sdk/core/src/mock/mock_response.rs new file mode 100644 index 0000000000..51d33d47e9 --- /dev/null +++ b/sdk/core/src/mock/mock_response.rs @@ -0,0 +1,98 @@ +use bytes::Bytes; +use http::{header, HeaderMap, StatusCode}; +use serde::{Deserialize, Serialize}; +use std::collections::BTreeMap; + +use crate::{collect_pinned_stream, BytesStream, Response}; + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct MockResponse { + status: StatusCode, + headers: HeaderMap, + body: Bytes, +} + +impl From for Response { + fn from(mock_response: MockResponse) -> Self { + let bytes_stream: crate::BytesStream = mock_response.body.into(); + + Self::new( + mock_response.status, + mock_response.headers, + Box::pin(bytes_stream), + ) + } +} + +impl MockResponse { + pub(crate) fn new(status: StatusCode, headers: HeaderMap, body: Bytes) -> Self { + Self { + status, + headers, + body, + } + } + + pub(crate) async fn duplicate( + response: crate::Response, + ) -> Result<(crate::Response, Self), crate::StreamError> { + let (status_code, header_map, pinned_stream) = response.deconstruct(); + let response_bytes = collect_pinned_stream(pinned_stream).await?; + + let response = Response::new( + status_code, + header_map.clone(), + Box::pin(BytesStream::new(response_bytes.clone())), + ); + let mock_response = MockResponse::new(status_code, header_map, response_bytes); + + Ok((response, mock_response)) + } +} + +impl<'de> Deserialize<'de> for MockResponse { + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + use serde::de::Error; + let r = SerializedMockResponse::deserialize(deserializer)?; + let mut headers = HeaderMap::new(); + for (n, v) in r.headers.iter() { + let name = header::HeaderName::from_lowercase(n.as_bytes()).map_err(Error::custom)?; + let value = header::HeaderValue::from_str(&v).map_err(Error::custom)?; + headers.insert(name, value); + } + let body = Bytes::from(base64::decode(r.body).map_err(Error::custom)?); + let status = StatusCode::from_u16(r.status).map_err(Error::custom)?; + + Ok(Self::new(status, headers, body)) + } +} + +impl Serialize for MockResponse { + fn serialize(&self, serializer: S) -> Result + where + S: serde::ser::Serializer, + { + let mut headers = BTreeMap::new(); + for (h, v) in self.headers.iter() { + headers.insert(h.as_str().into(), v.to_str().unwrap().into()); + } + let status = self.status.as_u16(); + let body = base64::encode(&self.body as &[u8]); + let s = SerializedMockResponse { + status, + headers, + body, + }; + s.serialize(serializer) + } +} + +#[derive(Serialize, Deserialize)] +pub(crate) struct SerializedMockResponse { + status: u16, + headers: BTreeMap, + body: String, +} diff --git a/sdk/core/src/mock_transaction.rs b/sdk/core/src/mock/mock_transaction.rs similarity index 85% rename from sdk/core/src/mock_transaction.rs rename to sdk/core/src/mock/mock_transaction.rs index 317454182a..2f29b5ee03 100644 --- a/sdk/core/src/mock_transaction.rs +++ b/sdk/core/src/mock/mock_transaction.rs @@ -2,12 +2,6 @@ use std::path::PathBuf; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; -pub mod constants { - pub const TESTING_MODE_KEY: &str = "TESTING_MODE"; - pub const TESTING_MODE_REPLAY: &str = "REPLAY"; - pub const TESTING_MODE_RECORD: &str = "RECORD"; -} - #[derive(Debug, Clone)] pub(crate) struct MockTransaction { pub(crate) name: String, @@ -37,9 +31,9 @@ impl MockTransaction { pub(crate) fn file_path( &self, create_when_not_exist: bool, - ) -> Result { + ) -> Result { let mut path = PathBuf::from(workspace_root().map_err(|e| { - crate::MockFrameworkError::TransactionStorageError(format!( + super::MockFrameworkError::TransactionStorageError(format!( "could not read the workspace_root from the cargo metadata: {}", e, )) @@ -59,13 +53,13 @@ impl MockTransaction { if !path.exists() { if create_when_not_exist { std::fs::create_dir_all(&path).map_err(|e| { - crate::MockFrameworkError::IOError( + super::MockFrameworkError::IOError( format!("cannot create transaction folder: {}", path.display()), e, ) })?; } else { - return Err(crate::MockFrameworkError::MissingTransaction(format!( + return Err(super::MockFrameworkError::MissingTransaction(format!( "the transaction location '{}' does not exist", path.canonicalize().unwrap_or(path).display() ))); diff --git a/sdk/core/src/mock/mod.rs b/sdk/core/src/mock/mod.rs new file mode 100644 index 0000000000..a3d0d9be32 --- /dev/null +++ b/sdk/core/src/mock/mod.rs @@ -0,0 +1,70 @@ +mod mock_request; +mod mock_response; +mod mock_transaction; +mod player_policy; +mod recorder_policy; + +use mock_transaction::MockTransaction; +use player_policy::MockTransportPlayerPolicy; +use recorder_policy::MockTransportRecorderPolicy; +use std::sync::Arc; + +pub const TESTING_MODE_KEY: &str = "TESTING_MODE"; +pub const TESTING_MODE_REPLAY: &str = "REPLAY"; +pub const TESTING_MODE_RECORD: &str = "RECORD"; + +/// An error relating to the mock transport framework. +#[cfg(feature = "mock_transport_framework")] +#[derive(Debug, thiserror::Error)] +pub(crate) enum MockFrameworkError { + #[error("{0}: {1}")] + IOError(String, std::io::Error), + #[error("{0}")] + TransactionStorageError(String), + #[error("{0}")] + MissingTransaction(String), + #[error("mismatched request uri. Actual '{0}', Expected: '{1}'")] + MismatchedRequestUri(String, String), + #[error("received request have header {0} but it was not present in the read request")] + MissingRequestHeader(String), + #[error("different number of headers in request. Actual: {0}, Expected: {1}")] + MismatchedRequestHeadersCount(usize, usize), + #[error("request header {0} value is different. Actual: {1}, Expected: {2}")] + MismatchedRequestHeader(String, String, String), + #[error("mismatched HTTP request method. Actual: {0}, Expected: {1}")] + MismatchedRequestHTTPMethod(http::Method, http::Method), + #[error("mismatched request body. Actual: {0:?}, Expected: {1:?}")] + MismatchedRequestBody(Vec, Vec), +} + +// Replace the default transport policy at runtime +// +// Replacement happens if these two conditions are met: +// 1. The mock_transport_framework is enabled +// 2. The environmental variable TESTING_MODE is either RECORD or PLAY +pub(crate) fn set_mock_transport_policy( + policy: &mut std::sync::Arc, + transport_options: crate::TransportOptions, +) { + match std::env::var(TESTING_MODE_KEY) + .as_deref() + .unwrap_or(TESTING_MODE_REPLAY) + { + TESTING_MODE_RECORD => { + log::warn!("mock testing framework record mode enabled"); + *policy = Arc::new(MockTransportRecorderPolicy::new(transport_options)) + } + TESTING_MODE_REPLAY => { + log::info!("mock testing framework replay mode enabled"); + *policy = Arc::new(MockTransportPlayerPolicy::new(transport_options)) + } + m => { + log::error!( + "invalid TESTING_MODE '{}' selected. Supported options are '{}' and '{}'", + m, + TESTING_MODE_RECORD, + TESTING_MODE_REPLAY + ); + } + }; +} diff --git a/sdk/core/src/policies/mock_transport_player_policy.rs b/sdk/core/src/mock/player_policy.rs similarity index 75% rename from sdk/core/src/policies/mock_transport_player_policy.rs rename to sdk/core/src/mock/player_policy.rs index 6a8b17dbb7..f342ddad9f 100644 --- a/sdk/core/src/policies/mock_transport_player_policy.rs +++ b/sdk/core/src/mock/player_policy.rs @@ -1,7 +1,7 @@ -use crate::bytes_response::BytesResponse; -use crate::mock_transaction::MockTransaction; +use super::mock_response::MockResponse; +use super::mock_transaction::MockTransaction; use crate::policies::{Policy, PolicyResult}; -use crate::{Context, MockFrameworkError, Request, Response, TransportOptions}; +use crate::{Context, Request, Response, TransportOptions}; use std::sync::Arc; #[derive(Debug, Clone)] @@ -48,7 +48,7 @@ impl Policy for MockTransportPlayerPolicy { }; let expected_request: Request = serde_json::from_str(&expected_request)?; - let expected_response = serde_json::from_str::(&expected_response)?; + let expected_response = serde_json::from_str::(&expected_response)?; let expected_uri = expected_request.uri().to_string(); let actual_uri = request @@ -57,7 +57,7 @@ impl Policy for MockTransportPlayerPolicy { .map(|p| p.to_string()) .unwrap_or_else(String::new); if expected_uri != actual_uri { - return Err(Box::new(MockFrameworkError::MismatchedRequestUri( + return Err(Box::new(super::MockFrameworkError::MismatchedRequestUri( actual_uri, expected_uri, ))); @@ -84,34 +84,40 @@ impl Policy for MockTransportPlayerPolicy { // 1. There are no extra headers (in both the received and read request). // 2. Each header has the same value. if actual_headers.len() != expected_headers.len() { - return Err(Box::new(MockFrameworkError::MismatchedRequestHeadersCount( - actual_headers.len(), - expected_headers.len(), - ))); + return Err(Box::new( + super::MockFrameworkError::MismatchedRequestHeadersCount( + actual_headers.len(), + expected_headers.len(), + ), + )); } for (actual_header_key, actual_header_value) in actual_headers.iter() { let (_, expected_header_value) = expected_headers .iter() .find(|(h, _)| actual_header_key.as_str() == h.as_str()) - .ok_or(MockFrameworkError::MissingRequestHeader( + .ok_or(super::MockFrameworkError::MissingRequestHeader( actual_header_key.as_str().to_owned(), ))?; if actual_header_value != expected_header_value { - return Err(Box::new(MockFrameworkError::MismatchedRequestHeader( - actual_header_key.as_str().to_owned(), - actual_header_value.to_str().unwrap().to_owned(), - expected_header_value.to_str().unwrap().to_owned(), - ))); + return Err(Box::new( + super::MockFrameworkError::MismatchedRequestHeader( + actual_header_key.as_str().to_owned(), + actual_header_value.to_str().unwrap().to_owned(), + expected_header_value.to_str().unwrap().to_owned(), + ), + )); } } if expected_request.method() != request.method() { - return Err(Box::new(MockFrameworkError::MismatchedRequestHTTPMethod( - expected_request.method(), - request.method(), - ))); + return Err(Box::new( + super::MockFrameworkError::MismatchedRequestHTTPMethod( + expected_request.method(), + request.method(), + ), + )); } let actual_body = match request.body() { @@ -125,7 +131,7 @@ impl Policy for MockTransportPlayerPolicy { }; if actual_body != expected_body { - return Err(Box::new(MockFrameworkError::MismatchedRequestBody( + return Err(Box::new(super::MockFrameworkError::MismatchedRequestBody( actual_body.to_vec(), expected_body.to_vec(), ))); diff --git a/sdk/core/src/policies/mock_transport_recorder_policy.rs b/sdk/core/src/mock/recorder_policy.rs similarity index 78% rename from sdk/core/src/policies/mock_transport_recorder_policy.rs rename to sdk/core/src/mock/recorder_policy.rs index 37e04d7125..03601f6855 100644 --- a/sdk/core/src/policies/mock_transport_recorder_policy.rs +++ b/sdk/core/src/mock/recorder_policy.rs @@ -1,7 +1,7 @@ -use crate::bytes_response::BytesResponse; -use crate::mock_transaction::MockTransaction; +use super::mock_response::MockResponse; +use super::MockTransaction; use crate::policies::{Policy, PolicyResult}; -use crate::{Context, MockFrameworkError, Request, Response, TransportOptions}; +use crate::{Context, Request, Response, TransportOptions}; use std::io::Write; use std::sync::Arc; @@ -45,7 +45,9 @@ impl Policy for MockTransportRecorderPolicy { let mut request_contents_stream = std::fs::File::create(&request_path).unwrap(); request_contents_stream .write_all(request_contents.as_str().as_bytes()) - .map_err(|e| MockFrameworkError::IOError("cannot write request file".into(), e))?; + .map_err(|e| { + super::MockFrameworkError::IOError("cannot write request file".into(), e) + })?; } let response = self @@ -56,13 +58,15 @@ impl Policy for MockTransportRecorderPolicy { // we need to duplicate the response because we are about to consume the response stream. // We replace the HTTP stream with a memory-backed stream. - let (response, bytes_response) = BytesResponse::duplicate(response).await?; - let response_contents = serde_json::to_string(&bytes_response).unwrap(); + let (response, mock_response) = MockResponse::duplicate(response).await?; + let response_contents = serde_json::to_string(&mock_response).unwrap(); { let mut response_contents_stream = std::fs::File::create(&response_path).unwrap(); response_contents_stream .write_all(response_contents.as_bytes()) - .map_err(|e| MockFrameworkError::IOError("cannot write response file".into(), e))?; + .map_err(|e| { + super::MockFrameworkError::IOError("cannot write response file".into(), e) + })?; } self.transaction.increment_number(); diff --git a/sdk/core/src/pipeline.rs b/sdk/core/src/pipeline.rs index 77e50100a1..76018d77c5 100644 --- a/sdk/core/src/pipeline.rs +++ b/sdk/core/src/pipeline.rs @@ -80,36 +80,8 @@ impl Pipeline { let mut policy: Arc = Arc::new(TransportPolicy::new(options.transport.clone())); - // This code replaces the default transport policy at runtime if these two conditions - // are met: - // 1. The mock_transport_framework is enabled - // 2. The environmental variable TESTING_MODE is either RECORD or PLAY #[cfg(feature = "mock_transport_framework")] - match std::env::var(crate::TESTING_MODE_KEY) - .as_deref() - .unwrap_or(crate::TESTING_MODE_REPLAY) - { - crate::TESTING_MODE_RECORD => { - log::warn!("mock testing framework record mode enabled"); - policy = Arc::new(crate::policies::MockTransportRecorderPolicy::new( - options.transport, - )) - } - crate::TESTING_MODE_REPLAY => { - log::info!("mock testing framework replay mode enabled"); - policy = Arc::new(crate::policies::MockTransportPlayerPolicy::new( - options.transport, - )) - } - m => { - log::error!( - "invalid TESTING_MODE '{}' selected. Supported options are '{}' and '{}'", - m, - crate::TESTING_MODE_RECORD, - crate::TESTING_MODE_REPLAY - ); - } - }; + crate::mock::set_mock_transport_policy(&mut policy, options.transport); pipeline.push(policy); } diff --git a/sdk/core/src/policies/mod.rs b/sdk/core/src/policies/mod.rs index bc6f5652ca..aa94b3258c 100644 --- a/sdk/core/src/policies/mod.rs +++ b/sdk/core/src/policies/mod.rs @@ -1,18 +1,10 @@ mod custom_headers_injector_policy; -#[cfg(feature = "mock_transport_framework")] -mod mock_transport_player_policy; -#[cfg(feature = "mock_transport_framework")] -mod mock_transport_recorder_policy; mod retry_policies; mod telemetry_policy; mod transport; use crate::{Context, Request, Response}; pub use custom_headers_injector_policy::{CustomHeaders, CustomHeadersInjectorPolicy}; -#[cfg(feature = "mock_transport_framework")] -pub use mock_transport_player_policy::MockTransportPlayerPolicy; -#[cfg(feature = "mock_transport_framework")] -pub use mock_transport_recorder_policy::MockTransportRecorderPolicy; pub use retry_policies::*; use std::error::Error; use std::sync::Arc; diff --git a/sdk/core/src/request.rs b/sdk/core/src/request.rs index 08071d4318..6bc2726b77 100644 --- a/sdk/core/src/request.rs +++ b/sdk/core/src/request.rs @@ -1,6 +1,5 @@ use crate::SeekableStream; use http::{HeaderMap, Method, Uri}; -use serde::ser::{Serialize, SerializeStruct, Serializer}; use std::fmt::Debug; /// An HTTP Body. @@ -24,66 +23,16 @@ impl From> for Body { } } -const FIELDS: &[&str] = &["uri", "method", "headers", "body"]; - /// A pipeline request. /// /// A pipeline request is composed by a destination (uri), a method, a collection of headers and a /// body. Policies are expected to enrich the request by mutating it. #[derive(Debug, Clone)] pub struct Request { - uri: Uri, - method: Method, - headers: HeaderMap, - body: Body, -} - -impl Request { - fn new(uri: Uri, method: Method, headers: HeaderMap, body: Body) -> Self { - Self { - uri, - method, - headers, - body, - } - } -} - -impl Serialize for Request { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - let mut hm = std::collections::BTreeMap::new(); - for (h, v) in self.headers().iter() { - if h.as_str().to_lowercase() == "authorization" { - hm.insert(h.to_string(), "<>"); - } else { - hm.insert(h.to_string(), v.to_str().unwrap()); - } - } - - let mut state = serializer.serialize_struct("Request", 4)?; - state.serialize_field( - FIELDS[0], - &self - .uri - .path_and_query() - .map(|p| p.to_string()) - .unwrap_or_else(String::new), - )?; - state.serialize_field(FIELDS[1], &self.method.to_string())?; - state.serialize_field(FIELDS[2], &hm)?; - state.serialize_field( - FIELDS[3], - &match &self.body { - Body::Bytes(bytes) => base64::encode(bytes as &[u8]), - Body::SeekableStream(_) => unimplemented!(), - }, - )?; - - state.end() - } + pub(crate) uri: Uri, + pub(crate) method: Method, + pub(crate) headers: HeaderMap, + pub(crate) body: Body, } impl Request { @@ -137,96 +86,3 @@ pub(crate) struct Parts { pub method: Method, pub headers: HeaderMap, } - -use serde::de::Visitor; -use serde::{Deserialize, Deserializer}; -use std::collections::HashMap; -use std::str::FromStr; - -impl<'de> Deserialize<'de> for Request { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_struct("Request", FIELDS, RequestVisitor) - } -} - -struct RequestVisitor; - -impl<'de> Visitor<'de> for RequestVisitor { - type Value = Request; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct Request") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let uri: (&str, &str) = match map.next_entry()? { - Some((a, b)) => (a, b), - None => return Err(serde::de::Error::custom("missing uri")), - }; - - if uri.0 != FIELDS[0] { - return Err(serde::de::Error::custom(format!( - "unexpected field {}, expected {}", - uri.0, FIELDS[0] - ))); - } - - let method: (&str, &str) = match map.next_entry()? { - Some((a, b)) => (a, b), - None => return Err(serde::de::Error::custom("missing method")), - }; - - if method.0 != FIELDS[1] { - return Err(serde::de::Error::custom(format!( - "unexpected field {}, expected {}", - method.0, FIELDS[1] - ))); - } - - let headers: (&str, HashMap<&str, String>) = match map.next_entry()? { - Some((a, b)) => (a, b), - None => return Err(serde::de::Error::custom("missing header map")), - }; - if headers.0 != FIELDS[2] { - return Err(serde::de::Error::custom(format!( - "unexpected field {}, expected {}", - headers.0, FIELDS[2] - ))); - } - - let body: (&str, String) = match map.next_entry()? { - Some((a, b)) => (a, b), - None => return Err(serde::de::Error::custom("missing body")), - }; - if body.0 != FIELDS[3] { - return Err(serde::de::Error::custom(format!( - "unexpected field {}, expected {}", - body.0, FIELDS[3] - ))); - } - - let body = base64::decode(&body.1).map_err(serde::de::Error::custom)?; - - let mut hm = HeaderMap::new(); - for (k, v) in headers.1.into_iter() { - hm.append( - http::header::HeaderName::from_lowercase(k.as_bytes()) - .map_err(serde::de::Error::custom)?, - http::HeaderValue::from_str(&v).map_err(serde::de::Error::custom)?, - ); - } - - Ok(Self::Value::new( - Uri::from_str(uri.1).expect("expected a valid uri"), - Method::from_str(method.1).expect("expected a valid HTTP method"), - hm, - bytes::Bytes::from(body).into(), - )) - } -} diff --git a/sdk/core/src/response.rs b/sdk/core/src/response.rs index afeb555023..efd09a433a 100644 --- a/sdk/core/src/response.rs +++ b/sdk/core/src/response.rs @@ -1,5 +1,3 @@ -use crate::bytes_response::BytesResponse; -use crate::BytesStream; use crate::StreamError; use bytes::Bytes; use futures::Stream; @@ -16,7 +14,6 @@ pub(crate) struct ResponseBuilder { } impl ResponseBuilder { - #[allow(dead_code)] pub fn new(status: StatusCode) -> Self { Self { status, @@ -30,7 +27,6 @@ impl ResponseBuilder { self } - #[allow(dead_code)] pub fn with_pinned_stream(self, response: PinnedStream) -> Response { Response::new(self.status, self.headers, response) } @@ -108,17 +104,3 @@ pub async fn pinned_stream_into_utf8_string(stream: PinnedStream) -> String { .to_owned(); body } - -impl From for Response { - fn from(bytes_response: BytesResponse) -> Self { - let (status, headers, body) = bytes_response.deconstruct(); - - let bytes_stream: BytesStream = body.into(); - - Self { - status, - headers, - body: Box::pin(bytes_stream), - } - } -} diff --git a/sdk/cosmos/examples/create_delete_database.rs.disabled b/sdk/cosmos/examples/create_delete_database.rs.disabled index 0193e0a4f2..9691e01f5a 100644 --- a/sdk/cosmos/examples/create_delete_database.rs.disabled +++ b/sdk/cosmos/examples/create_delete_database.rs.disabled @@ -15,7 +15,7 @@ async fn main() -> Result<(), Box> { .nth(1) .expect("please specify database name as first command line parameter"); - azure_core::mock_transport::start_transaction("create_delete_database"); + azure_core::mock::start_transaction("create_delete_database"); // This is how you construct an authorization token. // Remember to pick the correct token type. diff --git a/sdk/cosmos/tests/setup.rs b/sdk/cosmos/tests/setup.rs index 82eb9316af..b44c2da53b 100644 --- a/sdk/cosmos/tests/setup.rs +++ b/sdk/cosmos/tests/setup.rs @@ -26,12 +26,12 @@ fn get_authorization_token() -> Result, ) -> Result { - let account_name = (std::env::var(azure_core::TESTING_MODE_KEY).as_deref() - == Ok(azure_core::TESTING_MODE_RECORD)) + let account_name = (std::env::var(azure_core::mock::TESTING_MODE_KEY).as_deref() + == Ok(azure_core::mock::TESTING_MODE_RECORD)) .then(get_account) .unwrap_or_else(String::new); - let authorization_token = (std::env::var(azure_core::TESTING_MODE_KEY).as_deref() - == Ok(azure_core::TESTING_MODE_RECORD)) + let authorization_token = (std::env::var(azure_core::mock::TESTING_MODE_KEY).as_deref() + == Ok(azure_core::mock::TESTING_MODE_RECORD)) .then(|| get_authorization_token().ok()) .flatten() .unwrap_or_else(|| AuthorizationToken::new_resource(String::new()));