diff --git a/Cargo.lock b/Cargo.lock index d99fcd25a..144a233e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -361,6 +361,7 @@ dependencies = [ "hex 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)", "jsonwebtoken 7.1.1 (registry+https://github.com/rust-lang/crates.io-index)", "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "mockall 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", "mockito 0.26.0 (registry+https://github.com/rust-lang/crates.io-index)", "openssl 0.10.29 (registry+https://github.com/rust-lang/crates.io-index)", "regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)", @@ -1033,6 +1034,11 @@ dependencies = [ "strsim 0.9.3 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "downcast" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "dtoa" version = "0.4.5" @@ -1154,6 +1160,14 @@ dependencies = [ "miniz_oxide 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "float-cmp" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "num-traits 0.2.11 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "fnv" version = "1.0.6" @@ -1172,6 +1186,11 @@ name = "foreign-types-shared" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "fragile" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "fuchsia-cprng" version = "0.1.1" @@ -1883,6 +1902,31 @@ dependencies = [ "winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "mockall" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "downcast 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)", + "fragile 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "lazy_static 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)", + "mockall_derive 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)", + "predicates 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)", + "predicates-tree 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "mockall_derive" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "cfg-if 0.1.10 (registry+https://github.com/rust-lang/crates.io-index)", + "proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)", + "quote 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)", + "syn 1.0.33 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "mockito" version = "0.26.0" @@ -1952,6 +1996,11 @@ dependencies = [ "version_check 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "normalize-line-endings" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "num-bigint" version = "0.2.6" @@ -2154,6 +2203,32 @@ name = "ppv-lite86" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "predicates" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "difference 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "float-cmp 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)", + "normalize-line-endings 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)", + "predicates-core 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "regex 1.3.9 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "predicates-core" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + +[[package]] +name = "predicates-tree" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "predicates-core 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "treeline 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "proc-macro-error" version = "1.0.3" @@ -3745,6 +3820,11 @@ name = "tower-service" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "treeline" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" + [[package]] name = "trust-dns-proto" version = "0.18.0-alpha.2" @@ -4283,6 +4363,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum dirs-sys 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "afa0b23de8fd801745c471deffa6e12d248f962c9fd4b4c33787b055599bde7b" "checksum discard 1.0.4 (registry+https://github.com/rust-lang/crates.io-index)" = "212d0f5754cb6769937f4501cc0e67f4f4483c8d2c3e1e922ee9edbe4ab4c7c0" "checksum docopt 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7f525a586d310c87df72ebcd98009e57f1cc030c8c268305287a476beb653969" +"checksum downcast 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "4bb454f0228b18c7f4c3b0ebbee346ed9c52e7443b0999cd543ff3571205701d" "checksum dtoa 0.4.5 (registry+https://github.com/rust-lang/crates.io-index)" = "4358a9e11b9a09cf52383b451b49a169e8d797b68aa02301ff586d70d9661ea3" "checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3" "checksum encoding_rs 0.8.22 (registry+https://github.com/rust-lang/crates.io-index)" = "cd8d03faa7fe0c1431609dfad7bbe827af30f82e1e2ae6f7ee4fca6bd764bc28" @@ -4297,9 +4378,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum fernet 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "e7ac567fd75ce6bc28b68e63b5beaa3ce34f56bafd1122f64f8647c822e38a8b" "checksum fixedbitset 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "86d4de0081402f5e88cdac65c8dcdcc73118c1a7a465e2a05f0da05843a8ea33" "checksum flate2 1.0.13 (registry+https://github.com/rust-lang/crates.io-index)" = "6bd6d6f4752952feb71363cffc9ebac9411b75b87c6ab6058c40c8900cf43c0f" +"checksum float-cmp 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e1267f4ac4f343772758f7b1bdcbe767c218bbab93bb432acbf5162bbf85a6c4" "checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3" "checksum foreign-types 0.3.2 (registry+https://github.com/rust-lang/crates.io-index)" = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" "checksum foreign-types-shared 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +"checksum fragile 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "69a039c3498dc930fe810151a34ba0c1c70b02b8625035592e74432f678591f2" "checksum fuchsia-cprng 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" "checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82" "checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7" @@ -4377,12 +4460,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum mio-uds 0.6.7 (registry+https://github.com/rust-lang/crates.io-index)" = "966257a94e196b11bb43aca423754d87429960a768de9414f3691d6957abf125" "checksum miow 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "8c1f2f3b1cf331de6896aabf6e9d55dca90356cc9960cca7eaaf408a355ae919" "checksum miow 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "396aa0f2003d7df8395cb93e09871561ccc3e785f0acb369170e8cc74ddf9226" +"checksum mockall 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "256489d4d106cd2bc9e98ed0337402db0044de0621745d5d9eb70a14295ff77b" +"checksum mockall_derive 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)" = "826e14e8643cb12103b56efb963e5f9640b69b0f7bdcc460002092df4b0e959f" "checksum mockito 0.26.0 (registry+https://github.com/rust-lang/crates.io-index)" = "835b02e32817ac0638e05d06effef43a82820bc454ae4d28f6502cc65d1ce74f" "checksum mozsvc-common 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "efdfe5192ed6adb12e2f703d7a5f3facdfc3bda787a004930ee7ed2859aceb2e" "checksum native-tls 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "4b2df1a4c22fd44a62147fd8f13dd0f95c9d8ca7b2610299b2a2f9cf8964274e" "checksum net2 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)" = "42550d9fb7b6684a6d404d9fa7250c2eb2646df731d1c06afc06dcee9e1bcf88" "checksum nodrop 0.1.14 (registry+https://github.com/rust-lang/crates.io-index)" = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" "checksum nom 5.1.1 (registry+https://github.com/rust-lang/crates.io-index)" = "0b471253da97532da4b61552249c521e01e736071f71c1a4f7ebbfbf0a06aad6" +"checksum normalize-line-endings 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" "checksum num-bigint 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" "checksum num-integer 0.1.42 (registry+https://github.com/rust-lang/crates.io-index)" = "3f6ea62e9d81a77cd3ee9a2a5b9b609447857f3d358704331e4ef39eb247fcba" "checksum num-traits 0.1.43 (registry+https://github.com/rust-lang/crates.io-index)" = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31" @@ -4408,6 +4494,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum pin-utils 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" "checksum pkg-config 0.3.17 (registry+https://github.com/rust-lang/crates.io-index)" = "05da548ad6865900e60eaba7f589cc0783590a92e940c26953ff81ddbab2d677" "checksum ppv-lite86 0.2.6 (registry+https://github.com/rust-lang/crates.io-index)" = "74490b50b9fbe561ac330df47c08f3f33073d2d00c150f719147d7c54522fa1b" +"checksum predicates 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "96bfead12e90dccead362d62bb2c90a5f6fc4584963645bc7f71a735e0b0735a" +"checksum predicates-core 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "06075c3a3e92559ff8929e7a280684489ea27fe44805174c3ebd9328dcb37178" +"checksum predicates-tree 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8e63c4859013b38a76eca2414c64911fba30def9e3202ac461a2d22831220124" "checksum proc-macro-error 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "fc175e9777c3116627248584e8f8b3e2987405cabe1c0adf7d1dd28f09dc7880" "checksum proc-macro-error-attr 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3cc9795ca17eb581285ec44936da7fc2335a3f34f2ddd13118b6f4d515435c50" "checksum proc-macro-hack 0.5.16 (registry+https://github.com/rust-lang/crates.io-index)" = "7e0456befd48169b9f13ef0f0ad46d492cf9d2dbb918bcf38e01eed4ce3ec5e4" @@ -4558,6 +4647,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum tokio-util 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "be8242891f2b6cbef26a2d7e8605133c2c554cd35b3e4948ea892d6d68436499" "checksum toml 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)" = "ffc92d160b1eef40665be3a05630d003936a3bc7da7421277846c2613e92c71a" "checksum tower-service 0.3.0 (registry+https://github.com/rust-lang/crates.io-index)" = "e987b6bf443f4b5b3b6f38704195592cca41c5bb7aedd3c3693c7081f8289860" +"checksum treeline 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "a7f741b240f1a48843f9b8e0444fb55fb2a4ff67293b50a9179dfd5ea67f8d41" "checksum trust-dns-proto 0.18.0-alpha.2 (registry+https://github.com/rust-lang/crates.io-index)" = "2a7f3a2ab8a919f5eca52a468866a67ed7d3efa265d48a652a9a3452272b413f" "checksum trust-dns-resolver 0.18.0-alpha.2 (registry+https://github.com/rust-lang/crates.io-index)" = "6f90b1502b226f8b2514c6d5b37bafa8c200d7ca4102d57dc36ee0f3b7a04a2f" "checksum try-lock 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e604eb7b43c06650e854be16a2a03155743d3752dd1c943f6829e26b7a36e382" diff --git a/autoendpoint/Cargo.toml b/autoendpoint/Cargo.toml index a00d3f142..8081e7dad 100644 --- a/autoendpoint/Cargo.toml +++ b/autoendpoint/Cargo.toml @@ -45,6 +45,7 @@ validator_derive = "0.10.0" yup-oauth2 = "4.1.2" [dev-dependencies] +mockall = "0.7.1" mockito = "0.26.0" tempfile = "3.1.0" tokio = { version = "0.2.12", features = ["macros"] } diff --git a/autoendpoint/src/db/client.rs b/autoendpoint/src/db/client.rs index 225f5c5b0..dc3029198 100644 --- a/autoendpoint/src/db/client.rs +++ b/autoendpoint/src/db/client.rs @@ -3,6 +3,7 @@ use crate::db::retry::{ retry_policy, retryable_delete_error, retryable_describe_table_error, retryable_getitem_error, retryable_putitem_error, retryable_updateitem_error, }; +use async_trait::async_trait; use autopush_common::db::{DynamoDbNotification, DynamoDbUser}; use autopush_common::notification::Notification; use autopush_common::util::sec_since_epoch; @@ -22,15 +23,62 @@ use uuid::Uuid; const MAX_CHANNEL_TTL: u64 = 30 * 24 * 60 * 60; /// Provides high-level operations over the DynamoDB database +#[async_trait] +pub trait DbClient: Send + Sync { + /// Add a new user to the database. An error will occur if the user already + /// exists. + async fn add_user(&self, user: &DynamoDbUser) -> DbResult<()>; + + /// Read a user from the database + async fn get_user(&self, uaid: Uuid) -> DbResult>; + + /// Delete a user from the router table + async fn remove_user(&self, uaid: Uuid) -> DbResult<()>; + + /// Add a channel to a user + async fn add_channel(&self, uaid: Uuid, channel_id: Uuid) -> DbResult<()>; + + /// Get the set of channel IDs for a user + async fn get_channels(&self, uaid: Uuid) -> DbResult>; + + /// Remove the node ID from a user in the router table. + /// The node ID will only be cleared if `connected_at` matches up with the + /// item's `connected_at`. + async fn remove_node_id(&self, uaid: Uuid, node_id: String, connected_at: u64) -> DbResult<()>; + + /// Save a message to the message table + async fn save_message(&self, uaid: Uuid, message: Notification) -> DbResult<()>; + + /// Delete a notification + async fn remove_message(&self, uaid: Uuid, sort_key: String) -> DbResult<()>; + + /// Check if the router table exists + async fn router_table_exists(&self) -> DbResult; + + /// Check if the message table exists + async fn message_table_exists(&self) -> DbResult; + + /// Get the message table name + fn message_table(&self) -> &str; + + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Self { + self.box_clone() + } +} + #[derive(Clone)] -pub struct DbClient { +pub struct DbClientImpl { ddb: DynamoDbClient, metrics: StatsdClient, router_table: String, - pub message_table: String, + message_table: String, } -impl DbClient { +impl DbClientImpl { pub fn new( metrics: StatsdClient, router_table: String, @@ -57,9 +105,37 @@ impl DbClient { }) } - /// Add a new user to the database. An error will occur if the user already - /// exists. - pub async fn add_user(&self, user: &DynamoDbUser) -> DbResult<()> { + /// Check if a table exists + async fn table_exists(&self, table_name: String) -> DbResult { + let input = DescribeTableInput { table_name }; + + let output = match retry_policy() + .retry_if( + || self.ddb.describe_table(input.clone()), + retryable_describe_table_error(self.metrics.clone()), + ) + .await + { + Ok(output) => output, + Err(RusotoError::Service(DescribeTableError::ResourceNotFound(_))) => { + return Ok(false); + } + Err(e) => return Err(e.into()), + }; + + let status = output + .table + .and_then(|table| table.table_status) + .ok_or(DbError::TableStatusUnknown)? + .to_uppercase(); + + Ok(["CREATING", "UPDATING", "ACTIVE"].contains(&status.as_str())) + } +} + +#[async_trait] +impl DbClient for DbClientImpl { + async fn add_user(&self, user: &DynamoDbUser) -> DbResult<()> { let input = PutItemInput { item: serde_dynamodb::to_hashmap(user)?, table_name: self.router_table.to_string(), @@ -76,8 +152,7 @@ impl DbClient { Ok(()) } - /// Read a user from the database - pub async fn get_user(&self, uaid: Uuid) -> DbResult> { + async fn get_user(&self, uaid: Uuid) -> DbResult> { let input = GetItemInput { table_name: self.router_table.clone(), consistent_read: Some(true), @@ -97,8 +172,7 @@ impl DbClient { .map_err(DbError::from) } - /// Delete a user from the router table - pub async fn drop_user(&self, uaid: Uuid) -> DbResult<()> { + async fn remove_user(&self, uaid: Uuid) -> DbResult<()> { let input = DeleteItemInput { table_name: self.router_table.clone(), key: ddb_item! { uaid: s => uaid.to_simple().to_string() }, @@ -114,8 +188,7 @@ impl DbClient { Ok(()) } - /// Add a channel to a user - pub async fn add_channel(&self, uaid: Uuid, channel_id: Uuid) -> DbResult<()> { + async fn add_channel(&self, uaid: Uuid, channel_id: Uuid) -> DbResult<()> { let input = UpdateItemInput { table_name: self.message_table.clone(), key: ddb_item! { @@ -139,8 +212,7 @@ impl DbClient { Ok(()) } - /// Get the set of channel IDs for a user - pub async fn get_channels(&self, uaid: Uuid) -> DbResult> { + async fn get_channels(&self, uaid: Uuid) -> DbResult> { // Channel IDs are stored in a special row in the message table, where // chidmessageid = " " let input = GetItemInput { @@ -180,15 +252,7 @@ impl DbClient { Ok(channels) } - /// Remove the node ID from a user in the router table. - /// The node ID will only be cleared if `connected_at` matches up with the - /// item's `connected_at`. - pub async fn remove_node_id( - &self, - uaid: Uuid, - node_id: String, - connected_at: u64, - ) -> DbResult<()> { + async fn remove_node_id(&self, uaid: Uuid, node_id: String, connected_at: u64) -> DbResult<()> { let input = UpdateItemInput { key: ddb_item! { uaid: s => uaid.to_simple().to_string() }, update_expression: Some("REMOVE node_id".to_string()), @@ -211,8 +275,7 @@ impl DbClient { Ok(()) } - /// Store a single message - pub async fn store_message(&self, uaid: Uuid, message: Notification) -> DbResult<()> { + async fn save_message(&self, uaid: Uuid, message: Notification) -> DbResult<()> { let input = PutItemInput { item: serde_dynamodb::to_hashmap(&DynamoDbNotification::from_notif(&uaid, message))?, table_name: self.message_table.clone(), @@ -229,8 +292,7 @@ impl DbClient { Ok(()) } - /// Delete a notification - pub async fn delete_message(&self, uaid: Uuid, sort_key: String) -> DbResult<()> { + async fn remove_message(&self, uaid: Uuid, sort_key: String) -> DbResult<()> { let input = DeleteItemInput { table_name: self.message_table.clone(), key: ddb_item! { @@ -249,39 +311,19 @@ impl DbClient { Ok(()) } - /// Check if the router table exists - pub async fn router_table_exists(&self) -> DbResult { + async fn router_table_exists(&self) -> DbResult { self.table_exists(self.router_table.clone()).await } - /// Check if the message table exists - pub async fn message_table_exists(&self) -> DbResult { + async fn message_table_exists(&self) -> DbResult { self.table_exists(self.message_table.clone()).await } - /// Check if a table exists - async fn table_exists(&self, table_name: String) -> DbResult { - let describe_item = DescribeTableInput { table_name }; - - let output = match retry_policy() - .retry_if( - || self.ddb.describe_table(describe_item.clone()), - retryable_describe_table_error(self.metrics.clone()), - ) - .await - { - Ok(output) => output, - Err(RusotoError::Service(DescribeTableError::ResourceNotFound(_))) => { - return Ok(false); - } - Err(e) => return Err(e.into()), - }; - - let status = output - .table - .and_then(|table| table.table_status) - .ok_or(DbError::TableStatusUnknown)?; + fn message_table(&self) -> &str { + &self.message_table + } - Ok(["CREATING", "UPDATING", "ACTIVE"].contains(&status.as_str())) + fn box_clone(&self) -> Box { + Box::new(self.clone()) } } diff --git a/autoendpoint/src/db/mock.rs b/autoendpoint/src/db/mock.rs new file mode 100644 index 000000000..f81bad99e --- /dev/null +++ b/autoendpoint/src/db/mock.rs @@ -0,0 +1,102 @@ +// mockall::mock currently generates these warnings +#![allow(clippy::unused_unit)] +#![allow(clippy::ptr_arg)] + +use crate::db::client::DbClient; +use crate::db::error::DbResult; +use async_trait::async_trait; +use autopush_common::db::DynamoDbUser; +use autopush_common::notification::Notification; +use std::collections::HashSet; +use std::sync::Arc; +use uuid::Uuid; + +// mockall currently has issues mocking async traits with #[automock], so we use +// this workaround. See https://github.com/asomers/mockall/issues/75 +mockall::mock! { + pub DbClient { + fn add_user(&self, user: &DynamoDbUser) -> DbResult<()>; + + fn get_user(&self, uaid: Uuid) -> DbResult>; + + fn remove_user(&self, uaid: Uuid) -> DbResult<()>; + + fn add_channel(&self, uaid: Uuid, channel_id: Uuid) -> DbResult<()>; + + fn get_channels(&self, uaid: Uuid) -> DbResult>; + + fn remove_node_id(&self, uaid: Uuid, node_id: String, connected_at: u64) -> DbResult<()>; + + fn save_message(&self, uaid: Uuid, message: Notification) -> DbResult<()>; + + fn remove_message(&self, uaid: Uuid, sort_key: String) -> DbResult<()>; + + fn router_table_exists(&self) -> DbResult; + + fn message_table_exists(&self) -> DbResult; + + fn message_table(&self) -> &str; + + fn box_clone(&self) -> Box; + } +} + +#[async_trait] +impl DbClient for Arc { + async fn add_user(&self, user: &DynamoDbUser) -> DbResult<()> { + Arc::as_ref(self).add_user(user) + } + + async fn get_user(&self, uaid: Uuid) -> DbResult> { + Arc::as_ref(self).get_user(uaid) + } + + async fn remove_user(&self, uaid: Uuid) -> DbResult<()> { + Arc::as_ref(self).remove_user(uaid) + } + + async fn add_channel(&self, uaid: Uuid, channel_id: Uuid) -> DbResult<()> { + Arc::as_ref(self).add_channel(uaid, channel_id) + } + + async fn get_channels(&self, uaid: Uuid) -> DbResult> { + Arc::as_ref(self).get_channels(uaid) + } + + async fn remove_node_id(&self, uaid: Uuid, node_id: String, connected_at: u64) -> DbResult<()> { + Arc::as_ref(self).remove_node_id(uaid, node_id, connected_at) + } + + async fn save_message(&self, uaid: Uuid, message: Notification) -> DbResult<()> { + Arc::as_ref(self).save_message(uaid, message) + } + + async fn remove_message(&self, uaid: Uuid, sort_key: String) -> DbResult<()> { + Arc::as_ref(self).remove_message(uaid, sort_key) + } + + async fn router_table_exists(&self) -> DbResult { + Arc::as_ref(self).router_table_exists() + } + + async fn message_table_exists(&self) -> DbResult { + Arc::as_ref(self).message_table_exists() + } + + fn message_table(&self) -> &str { + Arc::as_ref(self).message_table() + } + + fn box_clone(&self) -> Box { + Box::new(Arc::clone(self)) + } +} + +impl MockDbClient { + /// Convert into a type which can be used in place of `Box`. + /// Arc is used so that the mock can be cloned. Box is used so it can be + /// easily cast to `Box`. + pub fn into_boxed_arc(self) -> Box> { + Box::new(Arc::new(self)) + } +} diff --git a/autoendpoint/src/db/mod.rs b/autoendpoint/src/db/mod.rs index 48ede147a..78111cb59 100644 --- a/autoendpoint/src/db/mod.rs +++ b/autoendpoint/src/db/mod.rs @@ -5,3 +5,6 @@ pub mod client; pub mod error; mod retry; + +#[cfg(test)] +pub mod mock; diff --git a/autoendpoint/src/extractors/user.rs b/autoendpoint/src/extractors/user.rs index 4f199579c..e62878aa9 100644 --- a/autoendpoint/src/extractors/user.rs +++ b/autoendpoint/src/extractors/user.rs @@ -23,13 +23,13 @@ pub async fn validate_user( Ok(router_type) => router_type, Err(_) => { debug!("Unknown router type, dropping user"; "user" => ?user); - drop_user(user.uaid, &state.ddb, &state.metrics).await?; + drop_user(user.uaid, state.ddb.as_ref(), &state.metrics).await?; return Err(ApiErrorKind::NoSubscription.into()); } }; if router_type == RouterType::WebPush { - validate_webpush_user(user, channel_id, &state.ddb, &state.metrics).await?; + validate_webpush_user(user, channel_id, state.ddb.as_ref(), &state.metrics).await?; } Ok(router_type) @@ -39,7 +39,7 @@ pub async fn validate_user( async fn validate_webpush_user( user: &DynamoDbUser, channel_id: &Uuid, - ddb: &DbClient, + ddb: &dyn DbClient, metrics: &StatsdClient, ) -> ApiResult<()> { // Make sure the user is active (has a valid message table) @@ -52,7 +52,7 @@ async fn validate_webpush_user( } }; - if ddb.message_table.as_str() != message_table { + if ddb.message_table() != message_table { debug!("User is inactive, dropping user"; "user" => ?user); drop_user(user.uaid, ddb, metrics).await?; return Err(ApiErrorKind::NoSubscription.into()); @@ -69,13 +69,13 @@ async fn validate_webpush_user( } /// Drop a user and increment associated metric -async fn drop_user(uaid: Uuid, ddb: &DbClient, metrics: &StatsdClient) -> ApiResult<()> { +async fn drop_user(uaid: Uuid, ddb: &dyn DbClient, metrics: &StatsdClient) -> ApiResult<()> { metrics .incr_with_tags("updates.drop_user") .with_tag("errno", "102") .send(); - ddb.drop_user(uaid).await?; + ddb.remove_user(uaid).await?; Ok(()) } diff --git a/autoendpoint/src/routers/fcm/router.rs b/autoendpoint/src/routers/fcm/router.rs index 45210b655..96cc2f180 100644 --- a/autoendpoint/src/routers/fcm/router.rs +++ b/autoendpoint/src/routers/fcm/router.rs @@ -1,3 +1,4 @@ +use crate::db::client::DbClient; use crate::error::{ApiError, ApiResult}; use crate::extractors::notification::Notification; use crate::extractors::router_data_input::RouterDataInput; @@ -12,6 +13,7 @@ use serde_json::Value; use std::collections::hash_map::RandomState; use std::collections::HashMap; use url::Url; +use uuid::Uuid; /// 28 days const MAX_TTL: usize = 28 * 24 * 60 * 60; @@ -21,6 +23,7 @@ pub struct FcmRouter { settings: FcmSettings, endpoint_url: Url, metrics: StatsdClient, + ddb: Box, /// A map from application ID to an authenticated FCM client clients: HashMap, } @@ -32,6 +35,7 @@ impl FcmRouter { endpoint_url: Url, http: reqwest::Client, metrics: StatsdClient, + ddb: Box, ) -> Result { let credentials = settings.credentials()?; let clients = Self::create_clients(&settings, credentials, http.clone()) @@ -42,6 +46,7 @@ impl FcmRouter { settings, endpoint_url, metrics, + ddb, clients, }) } @@ -91,7 +96,7 @@ impl FcmRouter { } /// Handle an error by logging, updating metrics, etc - fn handle_error(&self, error: FcmError) -> ApiError { + async fn handle_error(&self, error: FcmError, uaid: Uuid) -> ApiError { match &error { FcmError::FcmAuthentication => { error!("FCM authentication error"); @@ -106,8 +111,12 @@ impl FcmRouter { self.incr_error_metric("connection_unavailable"); } FcmError::FcmNotFound => { - debug!("FCM recipient not found"); + debug!("FCM recipient not found, removing user"); self.incr_error_metric("recipient_gone"); + + if let Err(e) = self.ddb.remove_user(uaid).await { + warn!("Error while removing user due to FCM 404: {}", e); + } } FcmError::FcmUpstream { .. } | FcmError::FcmUnknown => { warn!("FCM error: {error}", error = error.to_string()); @@ -200,7 +209,9 @@ impl Router for FcmRouter { let client = self.clients.get(app_id).ok_or(FcmError::InvalidAppId)?; trace!("Sending message to FCM: {:?}", message_data); if let Err(e) = client.send(message_data, fcm_token.to_string(), ttl).await { - return Err(self.handle_error(e)); + return Err(self + .handle_error(e, notification.subscription.user.uaid) + .await); } // Sent successfully, update metrics and make response @@ -219,6 +230,8 @@ impl Router for FcmRouter { #[cfg(test)] mod tests { + use crate::db::client::DbClient; + use crate::db::mock::MockDbClient; use crate::error::ApiErrorKind; use crate::extractors::notification::Notification; use crate::extractors::notification_headers::NotificationHeaders; @@ -234,6 +247,7 @@ mod tests { use crate::routers::{Router, RouterResponse}; use autopush_common::db::DynamoDbUser; use cadence::StatsdClient; + use mockall::predicate; use std::collections::HashMap; use std::path::PathBuf; use url::Url; @@ -248,7 +262,7 @@ mod tests { } /// Create a router for testing, using the given service auth file - async fn make_router(auth_file: PathBuf) -> FcmRouter { + async fn make_router(auth_file: PathBuf, ddb: Box) -> FcmRouter { FcmRouter::new( FcmSettings { fcm_url: Url::parse(&mockito::server_url()).unwrap(), @@ -262,6 +276,7 @@ mod tests { Url::parse("http://localhost:8080/").unwrap(), reqwest::Client::new(), StatsdClient::from_sink("autopush", cadence::NopMetricSink), + ddb, ) .await .unwrap() @@ -311,7 +326,8 @@ mod tests { #[tokio::test] async fn successful_routing_no_data() { let service_file = make_service_file(); - let router = make_router(service_file.path().to_owned()).await; + let ddb = MockDbClient::new().into_boxed_arc(); + let router = make_router(service_file.path().to_owned(), ddb).await; let _token_mock = mock_token_endpoint(); let fcm_mock = mock_fcm_endpoint_builder() .match_body( @@ -345,7 +361,8 @@ mod tests { #[tokio::test] async fn successful_routing_with_data() { let service_file = make_service_file(); - let router = make_router(service_file.path().to_owned()).await; + let ddb = MockDbClient::new().into_boxed_arc(); + let router = make_router(service_file.path().to_owned(), ddb).await; let _token_mock = mock_token_endpoint(); let fcm_mock = mock_fcm_endpoint_builder() .match_body( @@ -386,7 +403,8 @@ mod tests { #[tokio::test] async fn missing_client() { let service_file = make_service_file(); - let router = make_router(service_file.path().to_owned()).await; + let ddb = MockDbClient::new().into_boxed_arc(); + let router = make_router(service_file.path().to_owned(), ddb).await; let _token_mock = mock_token_endpoint(); let fcm_mock = mock_fcm_endpoint_builder().expect(0).create(); let mut router_data = default_router_data(); @@ -408,4 +426,34 @@ mod tests { ); fcm_mock.assert(); } + + /// If the FCM user no longer exists (404), we drop the user from our database + #[tokio::test] + async fn no_fcm_user() { + let notification = make_notification(default_router_data(), None); + let mut ddb = MockDbClient::new(); + ddb.expect_remove_user() + .with(predicate::eq(notification.subscription.user.uaid)) + .times(1) + .return_once(move |_| Ok(())); + + let service_file = make_service_file(); + let router = make_router(service_file.path().to_owned(), ddb.into_boxed_arc()).await; + let _token_mock = mock_token_endpoint(); + let _fcm_mock = mock_fcm_endpoint_builder() + .with_status(404) + .with_body(r#"{"error":{"status":"NOT_FOUND","message":"test-message"}}"#) + .create(); + + let result = router.route_notification(¬ification).await; + assert!(result.is_err()); + assert!( + matches!( + result.as_ref().unwrap_err().kind, + ApiErrorKind::Router(RouterError::Fcm(FcmError::FcmNotFound)) + ), + "result = {:?}", + result + ); + } } diff --git a/autoendpoint/src/routers/webpush.rs b/autoendpoint/src/routers/webpush.rs index 9e2abe58b..fc7bb2d9f 100644 --- a/autoendpoint/src/routers/webpush.rs +++ b/autoendpoint/src/routers/webpush.rs @@ -19,7 +19,7 @@ use uuid::Uuid; /// server is located via the database routing table. If the server is busy or /// not available, the notification is stored in the database. pub struct WebPushRouter { - pub ddb: DbClient, + pub ddb: Box, pub metrics: StatsdClient, pub http: reqwest::Client, pub endpoint_url: Url, @@ -151,7 +151,7 @@ impl WebPushRouter { /// Store a notification in the database async fn store_notification(&self, notification: &Notification) -> ApiResult<()> { self.ddb - .store_message( + .save_message( notification.subscription.user.uaid, notification.clone().into(), ) diff --git a/autoendpoint/src/routes/registration.rs b/autoendpoint/src/routes/registration.rs index 85733834d..47d90de6d 100644 --- a/autoendpoint/src/routes/registration.rs +++ b/autoendpoint/src/routes/registration.rs @@ -37,7 +37,7 @@ pub async fn register_uaid_route( let user = DynamoDbUser { router_type: path_args.router_type.to_string(), router_data: Some(router_data), - current_month: Some(state.ddb.message_table.clone()), + current_month: Some(state.ddb.message_table().to_string()), ..Default::default() }; let channel_id = router_data_input.channel_id.unwrap_or_else(Uuid::new_v4); diff --git a/autoendpoint/src/routes/webpush.rs b/autoendpoint/src/routes/webpush.rs index bb8ea02bd..9d94078bc 100644 --- a/autoendpoint/src/routes/webpush.rs +++ b/autoendpoint/src/routes/webpush.rs @@ -28,7 +28,7 @@ pub async fn delete_notification_route( trace!("message_id = {:?}", message_id); state .ddb - .delete_message(message_id.uaid(), sort_key) + .remove_message(message_id.uaid(), sort_key) .await?; Ok(HttpResponse::NoContent().finish()) diff --git a/autoendpoint/src/server.rs b/autoendpoint/src/server.rs index 9115d8f26..482c9455c 100644 --- a/autoendpoint/src/server.rs +++ b/autoendpoint/src/server.rs @@ -1,6 +1,6 @@ //! Main application server -use crate::db::client::DbClient; +use crate::db::client::{DbClient, DbClientImpl}; use crate::error::{ApiError, ApiResult}; use crate::metrics; use crate::routers::fcm::router::FcmRouter; @@ -23,7 +23,7 @@ pub struct ServerState { pub metrics: StatsdClient, pub settings: Settings, pub fernet: Arc, - pub ddb: DbClient, + pub ddb: Box, pub http: reqwest::Client, pub fcm_router: Arc, } @@ -35,11 +35,11 @@ impl Server { let metrics = metrics::metrics_from_opts(&settings)?; let bind_address = format!("{}:{}", settings.host, settings.port); let fernet = Arc::new(settings.make_fernet()); - let ddb = DbClient::new( + let ddb = Box::new(DbClientImpl::new( metrics.clone(), settings.router_table_name.clone(), settings.message_table_name.clone(), - )?; + )?); let http = reqwest::Client::new(); let fcm_router = Arc::new( FcmRouter::new( @@ -47,6 +47,7 @@ impl Server { settings.endpoint_url.clone(), http.clone(), metrics.clone(), + ddb.clone(), ) .await?, );