diff --git a/autopush/exceptions.py b/autopush/exceptions.py index ae775467..d63be592 100644 --- a/autopush/exceptions.py +++ b/autopush/exceptions.py @@ -33,6 +33,11 @@ class APNSException(Exception): pass +class MessageOverloadException(Exception): + """Too many messages per UAID""" + pass + + class RouterException(AutopushException): """Exception if routing has failed, may include a custom status_code and body to write to the response. diff --git a/autopush/main.py b/autopush/main.py index 50617807..840515db 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -145,6 +145,9 @@ def add_shared_args(parser): env_var="HUMAN_LOGS") parser.add_argument('--no_aws', help="Skip AWS meta information checks", action="store_true", default=False) + parser.add_argument('--msg_limit', help="Max limit for messages per uaid " + "before reset", type=int, default="100", + env_var="MSG_LIMIT") # No ENV because this is for humans add_external_router_args(parser) obsolete_args(parser) @@ -429,6 +432,7 @@ def make_settings(args, **kwargs): wake_timeout=args.wake_timeout, ami_id=ami_id, client_certs=client_certs, + msg_limit=args.msg_limit, **kwargs ) diff --git a/autopush/settings.py b/autopush/settings.py index 05c0be99..3b6f38b0 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -89,6 +89,7 @@ def __init__(self, preflight_uaid="deadbeef00000000deadbeef00000000", ami_id=None, client_certs=None, + msg_limit=100, ): """Initialize the Settings object @@ -170,6 +171,7 @@ def __init__(self, message_read_throughput=message_read_throughput, message_write_throughput=message_write_throughput) self._message_prefix = message_tablename + self.message_limit = msg_limit self.storage = Storage(self.storage_table, self.metrics) self.router = Router(self.router_table, self.metrics) diff --git a/autopush/tests/test_main.py b/autopush/tests/test_main.py index 505906af..39b0800b 100644 --- a/autopush/tests/test_main.py +++ b/autopush/tests/test_main.py @@ -245,6 +245,7 @@ class TestArg: fcm_auth = 'abcde' ssl_key = "keys/server.crt" ssl_cert = "keys/server.key" + msg_limit = 1000 _client_certs = dict(partner1=["1A:"*31 + "F9"], partner2=["2B:"*31 + "E8", "3C:"*31 + "D7"]) diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 21a89c9c..f1bf05e0 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -3,6 +3,7 @@ import time import uuid from hashlib import sha256 +from collections import defaultdict import twisted.internet.base from autopush.tests.test_db import make_webpush_notification @@ -1945,6 +1946,38 @@ def test_notif_finished_with_webpush_with_old_notifications(self): ok_(self.proto.force_retry.called) ok_(not self.send_mock.called) + def test_notif_finished_with_too_many_messages(self): + self._connect() + self.proto.ps.uaid = uuid.uuid4().hex + self.proto.ps.use_webpush = True + self.proto.ps._check_notifications = True + self.proto.ps.msg_limit = 2 + self.proto.ap_settings.router.drop_user = Mock() + self.proto.ps.message.fetch_messages = Mock() + + notif = make_webpush_notification( + self.proto.ps.uaid, + dummy_chid_str, + ttl=500 + ) + self.proto.ps.updates_sent = defaultdict(lambda: []) + self.proto.ps.message.fetch_messages.return_value = ( + None, + [notif, notif, notif] + ) + + d = Deferred() + + def check(*args, **kwargs): + ok_(self.proto.ap_settings.router.drop_user.called) + ok_(self.send_mock.called) + d.callback(True) + + self.proto.force_retry = Mock() + self.proto.process_notifications() + self.proto.ps._notification_fetch.addBoth(check) + return d + def test_notification_results(self): # Populate the database for ourself uaid = uuid.uuid4().hex diff --git a/autopush/websocket.py b/autopush/websocket.py index c8d421d8..3951f001 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -76,6 +76,7 @@ dump_uaid, ) from autopush.db import Message # noqa +from autopush.exceptions import MessageOverloadException from autopush.noseplugin import track_object from autopush.protocol import IgnoreBody from autopush.utils import ( @@ -167,6 +168,8 @@ class PushState(object): 'updates_sent', 'direct_updates', + 'msg_limit', + # iProducer methods 'pauseProducing', 'resumeProducing', @@ -215,6 +218,7 @@ def __init__(self, settings, request): self._check_notifications = False self._more_notifications = False + self.msg_limit = settings.message_limit # Timestamp message defaults self.scan_timestamps = False @@ -262,6 +266,7 @@ class PushServerProtocol(WebSocketServerProtocol, policies.TimeoutMixin): parent_class = WebSocketServerProtocol randrange = randrange _log_exc = True + sent_notification_count = 0 # Defer helpers def deferToThread(self, func, *args, **kwargs): @@ -871,6 +876,9 @@ def process_notifications(self): d.addCallback(self.finish_notifications) d.addErrback(self.error_notification_overload) d.addErrback(self.trap_cancel) + d.addErrback(self.error_message_overload) + # The following errback closes the connection. It must be the last + # errback in the chain. d.addErrback(self.error_notifications) self.ps._notification_fetch = d @@ -896,6 +904,12 @@ def error_notification_overload(self, fail): d = self.deferToLater(randrange(5, 60), self.process_notifications) d.addErrback(self.trap_cancel) + def error_message_overload(self, fail): + """errBack for handling excessive messages per UAID""" + fail.trap(MessageOverloadException) + self.force_retry(self.ap_settings.router.drop_user(self.ps.uaid)) + self.sendClose() + def finish_notifications(self, notifs): """callback for processing notifications from storage""" self.ps._notification_fetch = None @@ -956,6 +970,7 @@ def finish_webpush_notifications(self, result): # No more notifications, and we've scanned timestamped. self.ps._more_notifications = False self.ps.scan_timestamps = False + self.sent_notification_count = 0 if self.ps._check_notifications: # Told to check again, start over self.ps._check_notifications = False @@ -986,6 +1001,9 @@ def finish_webpush_notifications(self, result): self.ps.updates_sent[str(notif.channel_id)].append(notif) msg = notif.websocket_format() messages_sent = True + self.sent_notification_count += 1 + if self.sent_notification_count > self.ps.msg_limit: + raise MessageOverloadException() self.sendJSON(msg) # Did we send any messages?