diff --git a/Cargo.lock b/Cargo.lock index d2b3d2021..7d52f5c6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,7 +54,7 @@ dependencies = [ "config 0.8.0 (git+https://github.com/mehcode/config-rs?rev=e8fa9fee96185ddd18ebcef8a925c75459111edb)", "docopt 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "env_logger 0.5.10 (registry+https://github.com/rust-lang/crates.io-index)", - "error-chain 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)", + "error-chain 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)", "fernet 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", "futures 0.1.21 (registry+https://github.com/rust-lang/crates.io-index)", "futures-backoff 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)", @@ -429,7 +429,7 @@ dependencies = [ [[package]] name = "error-chain" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" dependencies = [ "backtrace 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)", @@ -2051,7 +2051,7 @@ dependencies = [ "checksum env_logger 0.5.10 (registry+https://github.com/rust-lang/crates.io-index)" = "0e6e40ebb0e66918a37b38c7acab4e10d299e0463fe2af5d29b9cc86710cfd2a" "checksum erased-serde 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "c564e32677839f1c551664c478e079c9b128a1a2d223180bffb2ddfabeded0be" "checksum error-chain 0.10.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d9435d864e017c3c6afeac1654189b06cdb491cf2ff73dbf0d73b0f292f42ff8" -"checksum error-chain 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "ff511d5dc435d703f4971bc399647c9bc38e20cb41452e3b9feb4765419ed3f3" +"checksum error-chain 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "07e791d3be96241c77c43846b665ef1384606da2cd2a48730abe606a12906e02" "checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed" "checksum fernet 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b001fae1d5ef9a63cb117462e0c6d76eca42891db7d6140ed12e8b1792277028" "checksum fixedbitset 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "86d4de0081402f5e88cdac65c8dcdcc73118c1a7a465e2a05f0da05843a8ea33" diff --git a/Cargo.toml b/Cargo.toml index c4bf60729..e4012ec74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ chan-signal = "0.3.1" chrono = "0.4.2" docopt = "1.0.0" env_logger = { version = "0.5.10", default-features = false } -error-chain = "0.11.0" +error-chain = "0.12.0" fernet = "0.1.0" futures = "0.1.21" futures-backoff = "0.1.0" diff --git a/src/client.rs b/src/client.rs index 79c010d6b..5b5e76c91 100644 --- a/src/client.rs +++ b/src/client.rs @@ -159,6 +159,7 @@ pub struct WebPushClient { // when all the unacked storeds are ack'd unacked_stored_highest: Option, connected_at: u64, + sent_from_storage: u32, stats: SessionStatistics, } @@ -175,6 +176,7 @@ impl Default for WebPushClient { unacked_stored_notifs: Default::default(), unacked_stored_highest: Default::default(), connected_at: Default::default(), + sent_from_storage: Default::default(), stats: Default::default(), } } @@ -188,9 +190,13 @@ impl WebPushClient { #[derive(Default)] pub struct ClientFlags { + /// Whether check_storage queries for topic (not "timestamped") messages include_topic: bool, + /// Flags the need to increment the last read for timestamp for timestamped messages increment_storage: bool, + /// Whether this client needs to check storage for messages check: bool, + /// Flags the need to drop the user record reset_uaid: bool, rotate_message_table: bool, } @@ -432,7 +438,6 @@ where }; let auth_state_machine = AuthClientState::start( vec![response], - false, AuthClientData { srv: srv.clone(), ws, @@ -549,10 +554,15 @@ where + Sink + 'static, { - #[state_machine_future(start, transitions(DetermineAck, SendThenWait))] - SendThenWait { - remaining_data: Vec, - poll_complete: bool, + #[state_machine_future(start, transitions(AwaitSend, DetermineAck))] + Send { + smessages: Vec, + data: AuthClientData, + }, + + #[state_machine_future(transitions(DetermineAck, Send, AwaitDropUser))] + AwaitSend { + smessages: Vec, data: AuthClientData, }, @@ -562,9 +572,7 @@ where DetermineAck { data: AuthClientData }, #[state_machine_future( - transitions( - DetermineAck, SendThenWait, AwaitInput, AwaitRegister, AwaitUnregister, AwaitDelete - ) + transitions(DetermineAck, Send, AwaitInput, AwaitRegister, AwaitUnregister, AwaitDelete) )] AwaitInput { data: AuthClientData }, @@ -580,7 +588,7 @@ where #[state_machine_future(transitions(AwaitCheckStorage))] CheckStorage { data: AuthClientData }, - #[state_machine_future(transitions(SendThenWait, DetermineAck))] + #[state_machine_future(transitions(Send, DetermineAck))] AwaitCheckStorage { response: MyFuture, data: AuthClientData, @@ -598,14 +606,14 @@ where data: AuthClientData, }, - #[state_machine_future(transitions(SendThenWait))] + #[state_machine_future(transitions(Send))] AwaitRegister { channel_id: Uuid, response: MyFuture, data: AuthClientData, }, - #[state_machine_future(transitions(SendThenWait))] + #[state_machine_future(transitions(Send))] AwaitUnregister { channel_id: Uuid, code: u32, @@ -632,27 +640,21 @@ where + Sink + 'static, { - fn poll_send_then_wait<'a>( - send: &'a mut RentToOwn<'a, SendThenWait>, - ) -> Poll, Error> { - trace!("State: SendThenWait"); - let start_send = { - let SendThenWait { - ref mut remaining_data, - poll_complete, + fn poll_send<'a>(send: &'a mut RentToOwn<'a, Send>) -> Poll, Error> { + trace!("State: Send"); + let sent = { + let Send { + ref mut smessages, ref mut data, .. } = **send; - if poll_complete { - try_ready!(data.ws.poll_complete()); - false - } else if !remaining_data.is_empty() { - let item = remaining_data.remove(0); + if !smessages.is_empty() { + let item = smessages.remove(0); let ret = data.ws.start_send(item).chain_err(|| "unable to send")?; match ret { AsyncSink::Ready => true, AsyncSink::NotReady(returned) => { - remaining_data.insert(0, returned); + smessages.insert(0, returned); return Ok(Async::NotReady); } } @@ -661,23 +663,34 @@ where } }; - let SendThenWait { - data, - remaining_data, - .. - } = send.take(); - if start_send { - transition!(SendThenWait { - remaining_data, - poll_complete: true, - data, - }); - } else if !remaining_data.is_empty() { - transition!(SendThenWait { - remaining_data, - poll_complete: false, - data, - }); + let Send { smessages, data } = send.take(); + if sent { + transition!(AwaitSend { smessages, data }); + } + transition!(DetermineAck { data }) + } + + fn poll_await_send<'a>( + await_send: &'a mut RentToOwn<'a, AwaitSend>, + ) -> Poll, Error> { + trace!("State: AwaitSend"); + try_ready!(await_send.data.ws.poll_complete()); + + let AwaitSend { smessages, data } = await_send.take(); + let webpush_rc = data.webpush.clone(); + let webpush = webpush_rc.borrow(); + if webpush.sent_from_storage > data.srv.opts.msg_limit { + // Exceeded the max limit of stored messages: drop the user to trigger a + // re-register + debug!("Dropping user: exceeded msg_limit"); + let response = Box::new( + data.srv + .ddb + .drop_uaid(&data.srv.opts.router_table_name, &webpush.uaid), + ); + transition!(AwaitDropUser { response, data }); + } else if !smessages.is_empty() { + transition!(Send { smessages, data }); } transition!(DetermineAck { data }) } @@ -703,6 +716,7 @@ where )); transition!(AwaitMigrateUser { response, data }); } else if all_acked && webpush.flags.reset_uaid { + debug!("Dropping user: flagged reset_uaid"); let response = Box::new( data.srv .ddb @@ -731,11 +745,10 @@ where ) }; if let Some(delta) = broadcast_delta { - transition!(SendThenWait { - remaining_data: vec![ServerMessage::Broadcast { + transition!(Send { + smessages: vec![ServerMessage::Broadcast { broadcasts: Broadcast::into_hashmap(delta), }], - poll_complete: false, data, }); } else { @@ -757,7 +770,8 @@ where let uaid = webpush.uaid; let message_month = webpush.message_month.clone(); let srv = data.srv.clone(); - let fut = data.srv + let fut = data + .srv .ddb .register(&srv, &uaid, &channel_id, &message_month, key); transition!(AwaitRegister { @@ -837,9 +851,8 @@ where } debug!("Got a notification to send, sending!"); emit_metrics_for_send(&data.srv.metrics, ¬if, "Direct"); - transition!(SendThenWait { - remaining_data: vec![ServerMessage::Notification(notif)], - poll_complete: false, + transition!(Send { + smessages: vec![ServerMessage::Notification(notif)], data, }); } @@ -925,53 +938,49 @@ where webpush.flags.include_topic = include_topic; debug!("Setting unacked stored highest to {:?}", timestamp); webpush.unacked_stored_highest = timestamp; + if messages.is_empty() { + webpush.flags.check = false; + webpush.sent_from_storage = 0; + transition!(DetermineAck { data }); + } + + // Filter out TTL expired messages + let now = sec_since_epoch(); + let srv = data.srv.clone(); + messages = messages + .into_iter() + .filter_map(|n| { + if !n.expired(now) { + return Some(n); + } + if n.sortkey_timestamp.is_none() { + srv.handle.spawn( + srv.ddb + .delete_message(&webpush.message_month, &webpush.uaid, &n) + .then(|_| { + debug!("Deleting expired message without sortkey_timestamp"); + Ok(()) + }), + ); + } + None + }) + .collect(); + webpush.flags.increment_storage = !include_topic && timestamp.is_some(); + // If there's still messages send them out if !messages.is_empty() { - // Filter out TTL expired messages - let now = sec_since_epoch() as u32; - let srv = data.srv.clone(); - messages = messages + webpush + .unacked_stored_notifs + .extend(messages.iter().cloned()); + let smessages: Vec<_> = messages .into_iter() - .filter_map(|n| { - if now >= n.ttl + n.timestamp { - if n.sortkey_timestamp.is_none() { - srv.handle.spawn( - srv.ddb - .delete_message(&webpush.message_month, &webpush.uaid, &n) - .then(|_| { - debug!( - "Deleting expired message without sortkey_timestamp" - ); - Ok(()) - }), - ); - } - None - } else { - Some(n) - } - }) + .inspect(|msg| emit_metrics_for_send(&data.srv.metrics, &msg, "Stored")) + .map(ServerMessage::Notification) .collect(); - webpush.flags.increment_storage = !include_topic && timestamp.is_some(); - // If there's still messages send them out - if !messages.is_empty() { - webpush - .unacked_stored_notifs - .extend(messages.iter().cloned()); - transition!(SendThenWait { - remaining_data: messages - .into_iter() - .inspect(|msg| emit_metrics_for_send(&data.srv.metrics, &msg, "Stored")) - .map(ServerMessage::Notification) - .collect(), - poll_complete: false, - data, - }) - } else { - // No messages remaining - transition!(DetermineAck { data }) - } + webpush.sent_from_storage += smessages.len() as u32; + transition!(Send { smessages, data }) } else { - webpush.flags.check = false; + // No messages remaining transition!(DetermineAck { data }) } } @@ -1028,9 +1037,8 @@ where } }; - transition!(SendThenWait { - remaining_data: vec![msg], - poll_complete: false, + transition!(Send { + smessages: vec![msg], data: await_register.take().data, }) } @@ -1061,9 +1069,8 @@ where .incr_with_tags("ua.command.unregister") .with_tag("code", &code.to_string()) .send(); - transition!(SendThenWait { - remaining_data: vec![msg], - poll_complete: false, + transition!(Send { + smessages: vec![msg], data, }) } diff --git a/src/protocol.rs b/src/protocol.rs index e318c6ae3..f76c75605 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -151,6 +151,10 @@ impl Notification { format!("{}:{}", chid, self.version) } } + + pub fn expired(&self, at_sec: u64) -> bool { + at_sec >= self.timestamp as u64 + self.ttl as u64 + } } fn default_ttl() -> u32 { diff --git a/src/server/mod.rs b/src/server/mod.rs index 7eb7a48e9..82ee881c5 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -140,6 +140,7 @@ pub struct ServerOptions { pub megaphone_api_token: Option, pub megaphone_poll_interval: Duration, pub human_logs: bool, + pub msg_limit: u32, } impl ServerOptions { @@ -197,6 +198,7 @@ impl ServerOptions { megaphone_poll_interval: ito_dur(settings.megaphone_poll_interval) .expect("megaphone poll interval cannot be 0"), human_logs: settings.human_logs, + msg_limit: settings.msg_limit, }; opts.message_table_names.sort_unstable(); opts.current_message_month = opts diff --git a/src/settings.rs b/src/settings.rs index 58be4c200..3812c4ee6 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -47,6 +47,7 @@ pub struct Settings { pub megaphone_api_token: Option, pub megaphone_poll_interval: u32, pub human_logs: bool, + pub msg_limit: u32, } impl Settings { @@ -72,6 +73,7 @@ impl Settings { s.set_default("statsd_port", 8125)?; s.set_default("megaphone_poll_interval", 30)?; s.set_default("human_logs", false)?; + s.set_default("msg_limit", 100)?; // Merge the configs from the files for filename in filenames { diff --git a/tests/test_integration.py b/tests/test_integration.py index 39532a2ba..0d2a39969 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -61,6 +61,7 @@ ROUTER_TABLE = os.environ.get("ROUTER_TABLE", "router_int_test") MESSAGE_TABLE = os.environ.get("MESSAGE_TABLE", "message_int_test") +MSG_LIMIT = 20 CRYPTO_KEY = Fernet.generate_key() CONNECTION_PORT = 9150 @@ -119,6 +120,7 @@ def setup_module(): close_handshake_timeout=5, max_connections=5000, human_logs="true", + msg_limit=MSG_LIMIT, ) rust_bin = root_dir + "/target/release/autopush_rs" possible_paths = ["/target/debug/autopush_rs", @@ -861,6 +863,26 @@ def test_with_bad_key(self): yield self.shut_down(client) + @inlineCallbacks + def test_msg_limit(self): + client = yield self.quick_register() + uaid = client.uaid + yield client.disconnect() + for i in range(MSG_LIMIT + 1): + yield client.send_notification() + yield client.connect() + yield client.hello() + assert client.uaid == uaid + for i in range(MSG_LIMIT): + result = yield client.get_notification() + assert result is not None + yield client.ack(result["channelID"], result["version"]) + yield client.disconnect() + yield client.connect() + yield client.hello() + assert client.uaid != uaid + yield self.shut_down(client) + class TestRustWebPushBroadcast(unittest.TestCase): _endpoint_defaults = dict(