From e9fc34ff626c13ec117f4ec9b091a69892bddf4f Mon Sep 17 00:00:00 2001 From: Jeeva Date: Tue, 2 May 2023 21:26:52 +0100 Subject: [PATCH] refactor(stripe): return all the missing fields in a request (#935) Co-authored-by: jeeva Co-authored-by: Sanchith Hegde <22217505+SanchithHegde@users.noreply.github.com> Co-authored-by: ItsMeShashank --- .gitignore | 2 + Cargo.lock | 13 +- crates/common_utils/Cargo.toml | 4 +- crates/common_utils/src/signals.rs | 41 ++++ crates/router/Cargo.toml | 4 +- .../src/connector/stripe/transformers.rs | 226 ++++++++++++++++-- crates/router/src/core/errors.rs | 2 + crates/router/src/macros.rs | 15 ++ crates/router/src/scheduler/utils.rs | 36 +++ 9 files changed, 315 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index ced4be3d728..a8e6412fb1a 100644 --- a/.gitignore +++ b/.gitignore @@ -256,3 +256,5 @@ loadtest/*.tmp/ # Nix output result* + +.idea/ diff --git a/Cargo.lock b/Cargo.lock index d065e56825f..0f7fa1ae308 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1221,7 +1221,7 @@ dependencies = [ "num-integer", "num-traits", "serde", - "time 0.1.43", + "time 0.1.45", "wasm-bindgen", "winapi", ] @@ -4420,11 +4420,12 @@ dependencies = [ [[package]] name = "time" -version = "0.1.43" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438" +checksum = "1b797afad3f312d1c66a56d11d0316f916356d11bd158fbc6ca6389ff6bf805a" dependencies = [ "libc", + "wasi 0.10.0+wasi-snapshot-preview1", "winapi", ] @@ -5018,6 +5019,12 @@ version = "0.9.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" +[[package]] +name = "wasi" +version = "0.10.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a143597ca7c7793eff794def352d41792a93c481eb1042423ff7ff72ba2c31f" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/crates/common_utils/Cargo.toml b/crates/common_utils/Cargo.toml index bd3d6c6dca3..a8f24cb6603 100644 --- a/crates/common_utils/Cargo.toml +++ b/crates/common_utils/Cargo.toml @@ -37,7 +37,6 @@ ring = "0.16.20" serde = { version = "1.0.160", features = ["derive"] } serde_json = "1.0.96" serde_urlencoded = "0.7.1" -signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"], optional = true } signal-hook = { version = "0.3.15", optional = true } tokio = { version = "1.27.0", features = ["macros", "rt-multi-thread"], optional = true } thiserror = "1.0.40" @@ -48,6 +47,9 @@ md5 = "0.7.0" masking = { version = "0.1.0", path = "../masking" } router_env = { version = "0.1.0", path = "../router_env", features = ["log_extra_implicit_fields", "log_custom_entries_to_extra"], optional = true } +[target.'cfg(not(target_os = "windows"))'.dependencies] +signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"], optional = true } + [dev-dependencies] fake = "2.5.0" proptest = "1.1.0" diff --git a/crates/common_utils/src/signals.rs b/crates/common_utils/src/signals.rs index 44118f39e30..5bde366bf3c 100644 --- a/crates/common_utils/src/signals.rs +++ b/crates/common_utils/src/signals.rs @@ -1,6 +1,8 @@ //! Provide Interface for worker services to handle signals +#[cfg(not(target_os = "windows"))] use futures::StreamExt; +#[cfg(not(target_os = "windows"))] use router_env::logger; use tokio::sync::mpsc; @@ -8,6 +10,7 @@ use tokio::sync::mpsc; /// This functions is meant to run in parallel to the application. /// It will send a signal to the receiver when a SIGTERM or SIGINT is received /// +#[cfg(not(target_os = "windows"))] pub async fn signal_handler(mut sig: signal_hook_tokio::Signals, sender: mpsc::Sender<()>) { if let Some(signal) = sig.next().await { logger::info!( @@ -31,9 +34,47 @@ pub async fn signal_handler(mut sig: signal_hook_tokio::Signals, sender: mpsc::S } } +/// +/// This functions is meant to run in parallel to the application. +/// It will send a signal to the receiver when a SIGTERM or SIGINT is received +/// +#[cfg(target_os = "windows")] +pub async fn signal_handler(_sig: DummySignal, _sender: mpsc::Sender<()>) {} + /// /// This function is used to generate a list of signals that the signal_handler should listen for /// +#[cfg(not(target_os = "windows"))] pub fn get_allowed_signals() -> Result { signal_hook_tokio::Signals::new([signal_hook::consts::SIGTERM, signal_hook::consts::SIGINT]) } + +/// +/// This function is used to generate a list of signals that the signal_handler should listen for +/// +#[cfg(target_os = "windows")] +pub fn get_allowed_signals() -> Result { + Ok(DummySignal) +} + +/// +/// Dummy Signal Handler for windows +/// +#[cfg(target_os = "windows")] +#[derive(Debug, Clone)] +pub struct DummySignal; + +#[cfg(target_os = "windows")] +impl DummySignal { + /// + /// Dummy handler for signals in windows (empty) + /// + pub fn handle(&self) -> Self { + self.clone() + } + + /// + /// Hollow implementation, for windows compatibility + /// + pub fn close(self) {} +} diff --git a/crates/router/Cargo.toml b/crates/router/Cargo.toml index be6d7cf5bba..102d0287ff7 100644 --- a/crates/router/Cargo.toml +++ b/crates/router/Cargo.toml @@ -69,7 +69,6 @@ serde_path_to_error = "0.1.11" serde_qs = { version = "0.12.0", optional = true } serde_urlencoded = "0.7.1" serde_with = "2.3.2" -signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"] } signal-hook = "0.3.15" strum = { version = "0.24.1", features = ["derive"] } thiserror = "1.0.40" @@ -94,6 +93,9 @@ aws-sdk-s3 = "0.25.0" aws-config = "0.55.1" infer = "0.13.0" +[target.'cfg(not(target_os = "windows"))'.dependencies] +signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"]} + [build-dependencies] router_env = { version = "0.1.0", path = "../router_env", default-features = false } diff --git a/crates/router/src/connector/stripe/transformers.rs b/crates/router/src/connector/stripe/transformers.rs index 7b7cbe86945..6c355d52eef 100644 --- a/crates/router/src/connector/stripe/transformers.rs +++ b/crates/router/src/connector/stripe/transformers.rs @@ -1,6 +1,6 @@ use api_models::{self, enums as api_enums, payments}; use base64::Engine; -use common_utils::{errors::CustomResult, fp_utils, pii}; +use common_utils::{errors::CustomResult, pii}; use error_stack::{IntoReport, ResultExt}; use masking::{ExposeInterface, ExposeOptionInterface, Secret}; use serde::{Deserialize, Serialize}; @@ -8,7 +8,7 @@ use url::Url; use uuid::Uuid; use crate::{ - consts, + collect_missing_value_keys, consts, core::errors, services, types::{self, api, storage::enums}, @@ -429,30 +429,21 @@ fn validate_shipping_address_against_payment_method( payment_method: &StripePaymentMethodType, ) -> Result<(), error_stack::Report> { if let StripePaymentMethodType::AfterpayClearpay = payment_method { - fp_utils::when(shipping_address.name.is_none(), || { - Err(errors::ConnectorError::MissingRequiredField { - field_name: "shipping.address.first_name", - }) - })?; - - fp_utils::when(shipping_address.line1.is_none(), || { - Err(errors::ConnectorError::MissingRequiredField { - field_name: "shipping.address.line1", - }) - })?; - - fp_utils::when(shipping_address.country.is_none(), || { - Err(errors::ConnectorError::MissingRequiredField { - field_name: "shipping.address.country", - }) - })?; + let missing_fields = collect_missing_value_keys!( + ("shipping.address.first_name", shipping_address.name), + ("shipping.address.line1", shipping_address.line1), + ("shipping.address.country", shipping_address.country), + ("shipping.address.zip", shipping_address.zip) + ); - fp_utils::when(shipping_address.zip.is_none(), || { - Err(errors::ConnectorError::MissingRequiredField { - field_name: "shipping.address.zip", + if !missing_fields.is_empty() { + return Err(errors::ConnectorError::MissingRequiredFields { + field_names: missing_fields, }) - })?; + .into_report(); + } } + Ok(()) } @@ -1799,3 +1790,192 @@ pub struct DisputeObj { pub dispute_id: String, pub status: String, } + +#[cfg(test)] +mod test_validate_shipping_address_against_payment_method { + #![allow(clippy::unwrap_used)] + use api_models::enums::CountryCode; + use masking::Secret; + + use crate::{ + connector::stripe::transformers::{ + validate_shipping_address_against_payment_method, StripePaymentMethodType, + StripeShippingAddress, + }, + core::errors, + }; + + #[test] + fn should_return_ok() { + // Arrange + let stripe_shipping_address = create_stripe_shipping_address( + Some("name".to_string()), + Some("line1".to_string()), + Some(CountryCode::AD), + Some("zip".to_string()), + ); + + let payment_method = &StripePaymentMethodType::AfterpayClearpay; + + //Act + let result = validate_shipping_address_against_payment_method( + &stripe_shipping_address, + payment_method, + ); + + // Assert + assert!(result.is_ok()); + } + + #[test] + fn should_return_err_for_empty_name() { + // Arrange + let stripe_shipping_address = create_stripe_shipping_address( + None, + Some("line1".to_string()), + Some(CountryCode::AD), + Some("zip".to_string()), + ); + + let payment_method = &StripePaymentMethodType::AfterpayClearpay; + + //Act + let result = validate_shipping_address_against_payment_method( + &stripe_shipping_address, + payment_method, + ); + + // Assert + assert!(result.is_err()); + let missing_fields = get_missing_fields(result.unwrap_err().current_context()).to_owned(); + assert_eq!(missing_fields.len(), 1); + assert_eq!(missing_fields[0], "shipping.address.first_name"); + } + + #[test] + fn should_return_err_for_empty_line1() { + // Arrange + let stripe_shipping_address = create_stripe_shipping_address( + Some("name".to_string()), + None, + Some(CountryCode::AD), + Some("zip".to_string()), + ); + + let payment_method = &StripePaymentMethodType::AfterpayClearpay; + + //Act + let result = validate_shipping_address_against_payment_method( + &stripe_shipping_address, + payment_method, + ); + + // Assert + assert!(result.is_err()); + let missing_fields = get_missing_fields(result.unwrap_err().current_context()).to_owned(); + assert_eq!(missing_fields.len(), 1); + assert_eq!(missing_fields[0], "shipping.address.line1"); + } + + #[test] + fn should_return_err_for_empty_country() { + // Arrange + let stripe_shipping_address = create_stripe_shipping_address( + Some("name".to_string()), + Some("line1".to_string()), + None, + Some("zip".to_string()), + ); + + let payment_method = &StripePaymentMethodType::AfterpayClearpay; + + //Act + let result = validate_shipping_address_against_payment_method( + &stripe_shipping_address, + payment_method, + ); + + // Assert + assert!(result.is_err()); + let missing_fields = get_missing_fields(result.unwrap_err().current_context()).to_owned(); + assert_eq!(missing_fields.len(), 1); + assert_eq!(missing_fields[0], "shipping.address.country"); + } + + #[test] + fn should_return_err_for_empty_zip() { + // Arrange + let stripe_shipping_address = create_stripe_shipping_address( + Some("name".to_string()), + Some("line1".to_string()), + Some(CountryCode::AD), + None, + ); + let payment_method = &StripePaymentMethodType::AfterpayClearpay; + + //Act + let result = validate_shipping_address_against_payment_method( + &stripe_shipping_address, + payment_method, + ); + + // Assert + assert!(result.is_err()); + let missing_fields = get_missing_fields(result.unwrap_err().current_context()).to_owned(); + assert_eq!(missing_fields.len(), 1); + assert_eq!(missing_fields[0], "shipping.address.zip"); + } + + #[test] + fn should_return_error_when_missing_multiple_fields() { + // Arrange + let expected_missing_field_names: Vec<&'static str> = + vec!["shipping.address.zip", "shipping.address.country"]; + let stripe_shipping_address = create_stripe_shipping_address( + Some("name".to_string()), + Some("line1".to_string()), + None, + None, + ); + let payment_method = &StripePaymentMethodType::AfterpayClearpay; + + //Act + let result = validate_shipping_address_against_payment_method( + &stripe_shipping_address, + payment_method, + ); + + // Assert + assert!(result.is_err()); + let missing_fields = get_missing_fields(result.unwrap_err().current_context()).to_owned(); + for field in missing_fields { + assert!(expected_missing_field_names.contains(&field)); + } + } + + fn get_missing_fields(connector_error: &errors::ConnectorError) -> Vec<&'static str> { + if let errors::ConnectorError::MissingRequiredFields { field_names } = connector_error { + return field_names.to_vec(); + } + + vec![] + } + + fn create_stripe_shipping_address( + name: Option, + line1: Option, + country: Option, + zip: Option, + ) -> StripeShippingAddress { + StripeShippingAddress { + name: name.map(Secret::new), + line1: line1.map(Secret::new), + country, + zip: zip.map(Secret::new), + city: Some(String::from("city")), + line2: Some(Secret::new(String::from("line2"))), + state: Some(Secret::new(String::from("state"))), + phone: Some(Secret::new(String::from("pbone number"))), + } + } +} diff --git a/crates/router/src/core/errors.rs b/crates/router/src/core/errors.rs index c4a6a7d863e..1377f3019a7 100644 --- a/crates/router/src/core/errors.rs +++ b/crates/router/src/core/errors.rs @@ -241,6 +241,8 @@ pub enum ConnectorError { ResponseHandlingFailed, #[error("Missing required field: {field_name}")] MissingRequiredField { field_name: &'static str }, + #[error("Missing required fields: {field_names:?}")] + MissingRequiredFields { field_names: Vec<&'static str> }, #[error("Failed to obtain authentication type")] FailedToObtainAuthType, #[error("Failed to obtain certificate")] diff --git a/crates/router/src/macros.rs b/crates/router/src/macros.rs index 2cd6310faf0..33ed43fcc7a 100644 --- a/crates/router/src/macros.rs +++ b/crates/router/src/macros.rs @@ -51,3 +51,18 @@ macro_rules! async_spawn { tokio::spawn(async move { $t }); }; } + +#[macro_export] +macro_rules! collect_missing_value_keys { + [$(($key:literal, $option:expr)),+] => { + { + let mut keys: Vec<&'static str> = Vec::new(); + $( + if $option.is_none() { + keys.push($key); + } + )* + keys + } + }; +} diff --git a/crates/router/src/scheduler/utils.rs b/crates/router/src/scheduler/utils.rs index 24ddf6b2a86..a58b02561a8 100644 --- a/crates/router/src/scheduler/utils.rs +++ b/crates/router/src/scheduler/utils.rs @@ -4,8 +4,11 @@ use std::{ }; use error_stack::{report, ResultExt}; +#[cfg(not(target_os = "windows"))] +use futures::StreamExt; use redis_interface::{RedisConnectionPool, RedisEntryId}; use router_env::opentelemetry; +use tokio::sync::oneshot; use uuid::Uuid; use super::{consumer, metrics, process_data, workflows}; @@ -376,3 +379,36 @@ where Ok(()) } } + +#[cfg(not(target_os = "windows"))] +pub(crate) async fn signal_handler( + mut sig: signal_hook_tokio::Signals, + sender: oneshot::Sender<()>, +) { + if let Some(signal) = sig.next().await { + logger::info!( + "Received signal: {:?}", + signal_hook::low_level::signal_name(signal) + ); + match signal { + signal_hook::consts::SIGTERM | signal_hook::consts::SIGINT => match sender.send(()) { + Ok(_) => { + logger::info!("Request for force shutdown received") + } + Err(_) => { + logger::error!( + "The receiver is closed, a termination call might already be sent" + ) + } + }, + _ => {} + } + } +} + +#[cfg(target_os = "windows")] +pub(crate) async fn signal_handler( + _sig: common_utils::signals::DummySignal, + _sender: oneshot::Sender<()>, +) { +}