diff --git a/autopush/tests/__init__.py b/autopush/tests/__init__.py index e69de29b..72173082 100644 --- a/autopush/tests/__init__.py +++ b/autopush/tests/__init__.py @@ -0,0 +1,17 @@ +class MockAssist(object): + def __init__(self, results): + self.cur = 0 + self.max = len(results) + self.results = results + + def __call__(self, *args, **kwargs): + try: + r = self.results[self.cur] + print r + if callable(r): + return r() + else: + return r + finally: + if self.cur < (self.max - 1): + self.cur += 1 diff --git a/autopush/tests/test_router.py b/autopush/tests/test_router.py index 062a0e9e..6e2eceea 100644 --- a/autopush/tests/test_router.py +++ b/autopush/tests/test_router.py @@ -32,6 +32,7 @@ FCMRouter) from autopush.router.interface import RouterException, RouterResponse, IRouter from autopush.settings import AutopushSettings +from autopush.tests import MockAssist mock_dynamodb2 = mock_dynamodb2() @@ -46,25 +47,6 @@ def tearDown(): mock_dynamodb2.stop() -class MockAssist(object): - def __init__(self, results): - self.cur = 0 - self.max = len(results) - self.results = results - - def __call__(self, *args, **kwargs): - try: - r = self.results[self.cur] - print r - if callable(r): - return r() - else: - return r - finally: - if self.cur < (self.max - 1): - self.cur += 1 - - class RouterInterfaceTestCase(TestCase): def test_not_implemented(self): assert_raises(NotImplementedError, IRouter, None, None) diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 49cb13f8..e2940f7c 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -23,6 +23,7 @@ create_rotating_message_table, ) from autopush.settings import AutopushSettings +from autopush.tests import MockAssist from autopush.websocket import ( PushState, PushServerProtocol, @@ -34,8 +35,6 @@ ) from autopush.utils import base64url_encode -from .test_router import MockAssist - def setUp(): from .test_integration import setUp @@ -554,6 +553,7 @@ def fake_msg(data): mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = [] + mock_msg.all_channels.return_value = (None, []) self.proto.ap_settings.router.register_user = fake_msg # massage message_tables to include our fake range mt = self.proto.ps.settings.message_tables @@ -571,15 +571,117 @@ def fake_msg(data): channelIDs=[], use_webpush=True)) + d = Deferred() + + def check_rotation(time_spent): + if time_spent > 3: # pragma: nocover + d.errback(Exception("Failed to rotate message table")) + + if self.proto.ps.rotate_message_table: # pragma: nocover + reactor.callLater(0.2, check_rotation, 0.2 + time_spent) + return + + eq_(self.proto.ps.rotate_message_table, False) + d.callback(True) + def check_result(msg): + # it's fine you've not connected in a while, but + # you should recycle your endpoints since they're probably + # invalid by now anyway. + eq_(msg["status"], 200) + eq_(msg["uaid"], orig_uaid) + + # Wait to see that the message table gets rotated + reactor.callLater(0.2, check_rotation, 0.2) + + self._check_response(check_result) + return d + + def test_hello_tomorrow_provision_error(self): + orig_uaid = "deadbeef00000000abad1dea00000000" + router = self.proto.ap_settings.router + router.register_user(dict( + uaid=orig_uaid, + connected_at=ms_time(), + current_month="message_2016_3", + router_type="simplepush", + )) + + # router.register_user returns (registered, previous + target_day = datetime.date(2016, 2, 29) + msg_day = datetime.date(2016, 3, 1) + msg_date = "{}_{}_{}".format( + self.proto.ap_settings._message_prefix, + msg_day.year, + msg_day.month) + msg_data = { + "router_type": "webpush", + "node_id": "http://localhost", + "last_connect": int(msg_day.strftime("%s")), + "current_month": msg_date, + } + + def fake_msg(data): + return (True, msg_data, data) + + mock_msg = Mock(wraps=db.Message) + mock_msg.fetch_messages.return_value = [] + mock_msg.all_channels.return_value = (None, []) + self.proto.ap_settings.router.register_user = fake_msg + # massage message_tables to include our fake range + mt = self.proto.ps.settings.message_tables + mt.clear() + mt['message_2016_1'] = mock_msg + mt['message_2016_2'] = mock_msg + mt['message_2016_3'] = mock_msg + + patch_range = patch("autopush.websocket.randrange") + mock_patch = patch_range.start() + mock_patch.return_value = 1 + + def raise_error(*args): + raise ProvisionedThroughputExceededException(None, None) + + self.proto.ap_settings.router.update_message_month = MockAssist([ + raise_error, + Mock(), + ]) + + with patch.object(datetime, 'date', + Mock(wraps=datetime.date)) as patched: + patched.today.return_value = target_day + self._connect() + self._send_message(dict(messageType="hello", + uaid=orig_uaid, + channelIDs=[], + use_webpush=True)) + + d = Deferred() + d.addBoth(lambda x: patch_range.stop()) + + def check_rotation(time_spent): + if time_spent > 3: # pragma: nocover + d.errback(Exception("Failed to rotate message table")) + + if self.proto.ps.rotate_message_table: # pragma: nocover + reactor.callLater(0.2, check_rotation, 0.2 + time_spent) + return + eq_(self.proto.ps.rotate_message_table, False) + d.callback(True) + + def check_result(msg): # it's fine you've not connected in a while, but # you should recycle your endpoints since they're probably # invalid by now anyway. eq_(msg["status"], 200) eq_(msg["uaid"], orig_uaid) - return self._check_response(check_result) + # Wait to see that the message table gets rotated + reactor.callLater(0.2, check_rotation, 0.2) + + self._check_response(check_result) + return d def test_hello(self): self._connect() @@ -713,7 +815,7 @@ def check_result(msg): return self._check_response(check_result) - def test_hello_provisioned_exception(self): + def test_hello_provisioned_during_check(self): self._connect() self.proto.randrange = Mock(return_value=0.1) # Fail out the register_user call @@ -1183,6 +1285,33 @@ def test_register_kill_others_fail(self): d.errback(ConnectError()) return d + def test_register_over_provisioning(self): + self._connect() + self.proto.ps.use_webpush = True + chid = str(uuid.uuid4()) + self.proto.ps.uaid = uuid.uuid4().hex + self.proto.ap_settings.message.register_channel = register = Mock() + + def throw_provisioned(*args, **kwargs): + raise ProvisionedThroughputExceededException(None, None) + + register.side_effect = throw_provisioned + + d = Deferred() + + def check_register_result(_): + ok_(self.proto.ap_settings.message.register_channel.called) + ok_(self.send_mock.called) + args, _ = self.send_mock.call_args + msg = json.loads(args[0]) + eq_(msg["messageType"], "error") + eq_(msg["reason"], "overloaded") + d.callback(True) + + res = self.proto.process_register(dict(channelID=chid)) + res.addCallback(check_register_result) + return d + def test_check_kill_self(self): self._connect() mock_agent = Mock() @@ -1614,48 +1743,63 @@ def wait(result): self.proto.ps._notification_fetch.addErrback(lambda x: d.errback(x)) return d - def test_process_notification_error(self): + def test_process_notifications_overload(self): + twisted.internet.base.DelayedCall.debug = True self._connect() self.proto.ps.uaid = uuid.uuid4().hex - def throw_error(*args, **kwargs): - raise Exception("An error happened!") + def throw_error(*args): + raise ProvisionedThroughputExceededException(None, None) + + # Swap out fetch_notifications + self.proto.ap_settings.storage.fetch_notifications = MockAssist([ + throw_error, + [], + ]) + + # Start the randrange patch + patch_randrange = patch("autopush.websocket.randrange") + mock_randrange = patch_randrange.start() + mock_randrange.return_value = 0.1 + + # No-op the deferToLater + self.proto.deferToLater = Mock() - self.proto.ap_settings.storage = Mock( - **{"fetch_notifications.side_effect": throw_error}) - self.proto.ps._check_notifications = True self.proto.process_notifications() + # Tag on our own to follow up d = Deferred() - def check_error(result): - eq_(self.proto.ps._check_notifications, False) - ok_(self.proto.log.failure.called) + def wait(result): + ok_(self.proto.deferToLater.called) + ok_(mock_randrange.called) + patch_randrange.stop() d.callback(True) - - self.proto.ps._notification_fetch.addBoth(check_error) + self.proto.ps._notification_fetch.addCallback(wait) + self.proto.ps._notification_fetch.addErrback(lambda x: d.errback(x)) return d - def test_process_notification_provisioned_error(self): + def test_process_notification_error(self): self._connect() - self.proto.randrange = Mock(return_value=0.1) self.proto.ps.uaid = uuid.uuid4().hex def throw_error(*args, **kwargs): - raise ProvisionedThroughputExceededException(None, None) + raise Exception("An error happened!") self.proto.ap_settings.storage = Mock( **{"fetch_notifications.side_effect": throw_error}) self.proto.ps._check_notifications = True self.proto.process_notifications() - def check_result(msg): + d = Deferred() + + def check_error(result): eq_(self.proto.ps._check_notifications, False) - eq_(msg["status"], 503) - eq_(msg["reason"], "error - overloaded") - self.flushLoggedErrors() + ok_(self.proto.log.failure.called) + d.callback(True) - return self._check_response(check_result) + self.proto.ps._notification_fetch.addBoth(check_error) + return d def test_process_notif_doesnt_run_with_webpush_outstanding(self): self._connect() diff --git a/autopush/websocket.py b/autopush/websocket.py index d803001d..d0e83efd 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -583,21 +583,34 @@ def returnError(self, messageType, reason, statusCode, close=True, if close: self.sendClose() - def err_overload(self, failure, message_type): + def err_overload(self, failure, message_type, disconnect=True): """Handle database overloads - Pause producing to cease incoming notifications while we wait a random - interval up to 8 seconds before closing down the connection. Most - clients wait up to 10 seconds for a command, but this is not a - guarantee, so rather than never reply, we still shut the connection - down. + If ``disconnect`` is False, the an overload error is returned and the + client is not disconnected. + + Otherwise, pause producing to cease incoming notifications while we + wait a random interval up to 8 seconds before closing down the + connection. Most clients wait up to 10 seconds for a command, + but this is not a guarantee, so rather than never reply, we still + shut the connection down. + + :param disconnect: Whether the client should be disconnected or not. """ failure.trap(ProvisionedThroughputExceededException) - self.transport.pauseProducing() - d = self.deferToLater(self.randrange(4, 9), self.err_finish_overload, - message_type) - d.addErrback(self.trap_cancel) + + if disconnect: + self.transport.pauseProducing() + d = self.deferToLater(self.randrange(4, 9), + self.err_finish_overload, message_type) + d.addErrback(self.trap_cancel) + else: + send = {"messageType": "error", + "reason": "overloaded", + "status": 503 + } + self.sendJSON(send) def err_finish_overload(self, message_type): """Close the connection down and resume consuming input after the @@ -844,8 +857,8 @@ def process_notifications(self): d = self.deferToThread( self.ap_settings.storage.fetch_notifications, self.ps.uaid) d.addCallback(self.finish_notifications) + d.addErrback(self.error_notification_overload) d.addErrback(self.trap_cancel) - d.addErrback(self.err_overload, "notif") d.addErrback(self.error_notifications) self.ps._notification_fetch = d @@ -855,6 +868,14 @@ def error_notifications(self, fail): self.log_failure(fail) self.sendClose() + def error_notification_overload(self, fail): + """errBack for provisioned errors during notification check""" + fail.trap(ProvisionedThroughputExceededException) + # Silently ignore the error, and reschedule the notification check + # to run up to a minute in the future to distribute load farther out + d = self.deferToLater(randrange(5, 60), self.process_notifications) + d.addErrback(self.trap_cancel) + def finish_notifications(self, notifs): """callback for processing notifications from storage""" self.ps._notification_fetch = None @@ -912,8 +933,6 @@ def finish_webpush_notifications(self, notifs): # Not told to check for notifications, do we need to now rotate # the message table? if self.ps.rotate_message_table: - self.transport.pauseProducing() - self.ps.rotate_message_table = False self._rotate_message_table() return @@ -933,38 +952,53 @@ def _rotate_message_table(self): """Function to fire off a message table copy of channels + update the router current_month entry""" self.transport.pauseProducing() - d = self.deferToThread(self.ps.message.all_channels, self.ps.uaid) - d.addCallback(self._register_rotated_channels) + d = self.deferToThread(self._monthly_transition) + d.addCallback(self._finish_monthly_transition) d.addErrback(self.trap_cancel) - d.addErrback(self.err_overload, "notif") - d.addErrback(self.log_failure) + d.addErrback(self.error_monthly_rotation_overload) + d.addErrback(self.error_notifications) - def _register_rotated_channels(self, result): - """Register the channels into a new entry in the current month""" - # Update the current month now, so that we can save the channels into - # the right location + def _monthly_transition(self): + """Transition the client to use a new message month + + Utilized to migrate a users channels to a new message month and + update the router record reflecting the proper month. + + This is a blocking function that does *not* run on the event loop. + + """ + # Get the current channels for this month + _, channels = self.ps.message.all_channels(self.ps.uaid) + + # Get the current message month + cur_month = self.ap_settings.current_msg_month + if channels: + # Save the current channels into this months message table + msg_table = self.ap_settings.message_tables[cur_month] + msg_table.save_channels(self.ps.uaid, channels) + + # Finally, update the route message month + self.ap_settings.router.update_message_month(self.ps.uaid, cur_month) + + def _finish_monthly_transition(self, result): + """Mark the client as successfully transitioned and resume""" + # Update the current month now that we've moved forward a month self.ps.message_month = self.ap_settings.current_msg_month + self.ps.rotate_message_table = False + self.transport.resumeProducing() - _, channels = result - if not channels: - # No previously registered channels, skip to updating the router - # table - return self._update_router_for_message_month(None) + def error_monthly_rotation_overload(self, fail): + """Capture overload on monthly table rotation attempt - # Register the channels, then update the router - d = self.deferToThread(self.ps.message.save_channels, self.ps.uaid, - channels) - d.addCallback(self._update_router_for_message_month) - return d + If a provision exdeeded error hits while attempting monthly table + rotation, schedule it all over and re-scan the messages. Normal + websocket client flow is returned in the meantime. - def _update_router_for_message_month(self, result): - """Update the router for the message month""" - # This is returned so that the error handling in _rotate_message_table - # still applies since the deferred chain is fully followed. - d = self.deferToThread(self.ap_settings.router.update_message_month, - self.ps.uaid, self.ps.message_month) - d.addCallback(lambda x: self.transport.resumeProducing()) - return d + """ + fail.trap(ProvisionedThroughputExceededException) + self.transport.resumeProducing() + d = self.deferToLater(randrange(1, 60), self.process_notifications) + d.addErrback(self.trap_cancel) def _send_ping(self): """Helper for ping sending that tracks when the ping was sent""" @@ -1027,12 +1061,12 @@ def error_register(self, fail): def finish_register(self, endpoint, chid): """callback for successful endpoint creation, sends register reply""" if self.ps.use_webpush: - d = self.deferToThread(self.ap_settings.message.register_channel, + d = self.deferToThread(self.ps.message.register_channel, self.ps.uaid, chid) d.addCallback(self.send_register_finish, endpoint, chid) # Note: No trap_cancel needed here since the deferred here is # returned to process_register which will trap it - d.addErrback(self.err_overload, "register") + d.addErrback(self.err_overload, "register", disconnect=False) return d else: self.send_register_finish(None, endpoint, chid) @@ -1084,7 +1118,7 @@ def process_unregister(self, data): if self.ps.use_webpush: # Unregister the channel - self.force_retry(self.ap_settings.message.unregister_channel, + self.force_retry(self.ps.message.unregister_channel, self.ps.uaid, chid) else: # Delete any record from storage, we don't wait for this