diff --git a/src/client.rs b/src/client.rs index bbe33d659..5b5e76c91 100644 --- a/src/client.rs +++ b/src/client.rs @@ -159,6 +159,7 @@ pub struct WebPushClient { // when all the unacked storeds are ack'd unacked_stored_highest: Option, connected_at: u64, + sent_from_storage: u32, stats: SessionStatistics, } @@ -175,6 +176,7 @@ impl Default for WebPushClient { unacked_stored_notifs: Default::default(), unacked_stored_highest: Default::default(), connected_at: Default::default(), + sent_from_storage: Default::default(), stats: Default::default(), } } @@ -558,7 +560,7 @@ where data: AuthClientData, }, - #[state_machine_future(transitions(DetermineAck, Send))] + #[state_machine_future(transitions(DetermineAck, Send, AwaitDropUser))] AwaitSend { smessages: Vec, data: AuthClientData, @@ -675,7 +677,19 @@ where try_ready!(await_send.data.ws.poll_complete()); let AwaitSend { smessages, data } = await_send.take(); - if !smessages.is_empty() { + let webpush_rc = data.webpush.clone(); + let webpush = webpush_rc.borrow(); + if webpush.sent_from_storage > data.srv.opts.msg_limit { + // Exceeded the max limit of stored messages: drop the user to trigger a + // re-register + debug!("Dropping user: exceeded msg_limit"); + let response = Box::new( + data.srv + .ddb + .drop_uaid(&data.srv.opts.router_table_name, &webpush.uaid), + ); + transition!(AwaitDropUser { response, data }); + } else if !smessages.is_empty() { transition!(Send { smessages, data }); } transition!(DetermineAck { data }) @@ -702,6 +716,7 @@ where )); transition!(AwaitMigrateUser { response, data }); } else if all_acked && webpush.flags.reset_uaid { + debug!("Dropping user: flagged reset_uaid"); let response = Box::new( data.srv .ddb @@ -925,6 +940,7 @@ where webpush.unacked_stored_highest = timestamp; if messages.is_empty() { webpush.flags.check = false; + webpush.sent_from_storage = 0; transition!(DetermineAck { data }); } @@ -956,14 +972,13 @@ where webpush .unacked_stored_notifs .extend(messages.iter().cloned()); - transition!(Send { - smessages: messages - .into_iter() - .inspect(|msg| emit_metrics_for_send(&data.srv.metrics, &msg, "Stored")) - .map(ServerMessage::Notification) - .collect(), - data, - }) + let smessages: Vec<_> = messages + .into_iter() + .inspect(|msg| emit_metrics_for_send(&data.srv.metrics, &msg, "Stored")) + .map(ServerMessage::Notification) + .collect(); + webpush.sent_from_storage += smessages.len() as u32; + transition!(Send { smessages, data }) } else { // No messages remaining transition!(DetermineAck { data }) diff --git a/src/server/mod.rs b/src/server/mod.rs index 7eb7a48e9..82ee881c5 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -140,6 +140,7 @@ pub struct ServerOptions { pub megaphone_api_token: Option, pub megaphone_poll_interval: Duration, pub human_logs: bool, + pub msg_limit: u32, } impl ServerOptions { @@ -197,6 +198,7 @@ impl ServerOptions { megaphone_poll_interval: ito_dur(settings.megaphone_poll_interval) .expect("megaphone poll interval cannot be 0"), human_logs: settings.human_logs, + msg_limit: settings.msg_limit, }; opts.message_table_names.sort_unstable(); opts.current_message_month = opts diff --git a/src/settings.rs b/src/settings.rs index 58be4c200..3812c4ee6 100644 --- a/src/settings.rs +++ b/src/settings.rs @@ -47,6 +47,7 @@ pub struct Settings { pub megaphone_api_token: Option, pub megaphone_poll_interval: u32, pub human_logs: bool, + pub msg_limit: u32, } impl Settings { @@ -72,6 +73,7 @@ impl Settings { s.set_default("statsd_port", 8125)?; s.set_default("megaphone_poll_interval", 30)?; s.set_default("human_logs", false)?; + s.set_default("msg_limit", 100)?; // Merge the configs from the files for filename in filenames { diff --git a/tests/test_integration.py b/tests/test_integration.py index 39532a2ba..0d2a39969 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -61,6 +61,7 @@ ROUTER_TABLE = os.environ.get("ROUTER_TABLE", "router_int_test") MESSAGE_TABLE = os.environ.get("MESSAGE_TABLE", "message_int_test") +MSG_LIMIT = 20 CRYPTO_KEY = Fernet.generate_key() CONNECTION_PORT = 9150 @@ -119,6 +120,7 @@ def setup_module(): close_handshake_timeout=5, max_connections=5000, human_logs="true", + msg_limit=MSG_LIMIT, ) rust_bin = root_dir + "/target/release/autopush_rs" possible_paths = ["/target/debug/autopush_rs", @@ -861,6 +863,26 @@ def test_with_bad_key(self): yield self.shut_down(client) + @inlineCallbacks + def test_msg_limit(self): + client = yield self.quick_register() + uaid = client.uaid + yield client.disconnect() + for i in range(MSG_LIMIT + 1): + yield client.send_notification() + yield client.connect() + yield client.hello() + assert client.uaid == uaid + for i in range(MSG_LIMIT): + result = yield client.get_notification() + assert result is not None + yield client.ack(result["channelID"], result["version"]) + yield client.disconnect() + yield client.connect() + yield client.hello() + assert client.uaid != uaid + yield self.shut_down(client) + class TestRustWebPushBroadcast(unittest.TestCase): _endpoint_defaults = dict(