Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: handle provisioned errors gracefully
Browse files Browse the repository at this point in the history
Several db operations from websocket clients now handle provisioned exceeded
errors more gracefully. Register now returns an error message about the overload
so that a client may retry. During the initial notification check, if an error occurs the
notification check is re-scheduled to occur at a later time. Finally, if the provision
error hits during the monthly change-over, the client remains connected and the
migration is re-scheduled.

Closes #658
  • Loading branch information
bbangert committed Oct 5, 2016
1 parent 5cbfd81 commit ece6608
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 79 deletions.
17 changes: 17 additions & 0 deletions autopush/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class MockAssist(object):
def __init__(self, results):
self.cur = 0
self.max = len(results)
self.results = results

def __call__(self, *args, **kwargs):
try:
r = self.results[self.cur]
print r
if callable(r):
return r()
else:
return r
finally:
if self.cur < (self.max - 1):
self.cur += 1
20 changes: 1 addition & 19 deletions autopush/tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
FCMRouter)
from autopush.router.interface import RouterException, RouterResponse, IRouter
from autopush.settings import AutopushSettings
from autopush.tests import MockAssist


mock_dynamodb2 = mock_dynamodb2()
Expand All @@ -46,25 +47,6 @@ def tearDown():
mock_dynamodb2.stop()


class MockAssist(object):
def __init__(self, results):
self.cur = 0
self.max = len(results)
self.results = results

def __call__(self, *args, **kwargs):
try:
r = self.results[self.cur]
print r
if callable(r):
return r()
else:
return r
finally:
if self.cur < (self.max - 1):
self.cur += 1


class RouterInterfaceTestCase(TestCase):
def test_not_implemented(self):
assert_raises(NotImplementedError, IRouter, None, None)
Expand Down
191 changes: 168 additions & 23 deletions autopush/tests/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
create_rotating_message_table,
)
from autopush.settings import AutopushSettings
from autopush.tests import MockAssist
from autopush.websocket import (
PushState,
PushServerProtocol,
Expand All @@ -34,8 +35,6 @@
)
from autopush.utils import base64url_encode

from .test_router import MockAssist


def setUp():
from .test_integration import setUp
Expand Down Expand Up @@ -554,6 +553,7 @@ def fake_msg(data):

mock_msg = Mock(wraps=db.Message)
mock_msg.fetch_messages.return_value = []
mock_msg.all_channels.return_value = (None, [])
self.proto.ap_settings.router.register_user = fake_msg
# massage message_tables to include our fake range
mt = self.proto.ps.settings.message_tables
Expand All @@ -571,15 +571,118 @@ def fake_msg(data):
channelIDs=[],
use_webpush=True))

d = Deferred()

def check_rotation(time_spent):
if time_spent > 3: # pragma: nocover
d.errback(Exception("Failed to rotate message table"))

if self.proto.ps.rotate_message_table: # pragma: nocover
reactor.callLater(0.2, check_rotation, 0.2 + time_spent)
return

eq_(self.proto.ps.rotate_message_table, False)
d.callback(True)

def check_result(msg):
# it's fine you've not connected in a while, but
# you should recycle your endpoints since they're probably
# invalid by now anyway.
eq_(msg["status"], 200)
eq_(msg["uaid"], orig_uaid)

# Wait to see that the message table gets rotated
reactor.callLater(0.2, check_rotation, 0.2)

self._check_response(check_result)
return d

def test_hello_tomorrow_provision_error(self):
orig_uaid = "deadbeef00000000abad1dea00000000"
router = self.proto.ap_settings.router
router.register_user(dict(
uaid=orig_uaid,
connected_at=ms_time(),
current_month="message_2016_3",
router_type="simplepush",
))

# router.register_user returns (registered, previous
target_day = datetime.date(2016, 2, 29)
msg_day = datetime.date(2016, 3, 1)
msg_date = "{}_{}_{}".format(
self.proto.ap_settings._message_prefix,
msg_day.year,
msg_day.month)
msg_data = {
"router_type": "webpush",
"node_id": "http://localhost",
"last_connect": int(msg_day.strftime("%s")),
"current_month": msg_date,
}

def fake_msg(data):
return (True, msg_data, data)

mock_msg = Mock(wraps=db.Message)
mock_msg.fetch_messages.return_value = []
mock_msg.all_channels.return_value = (None, [])
self.proto.ap_settings.router.register_user = fake_msg
# massage message_tables to include our fake range
mt = self.proto.ps.settings.message_tables
for k in mt.keys():
del(mt[k])
mt['message_2016_1'] = mock_msg
mt['message_2016_2'] = mock_msg
mt['message_2016_3'] = mock_msg

patch_range = patch("autopush.websocket.randrange")
mock_patch = patch_range.start()
mock_patch.return_value = 1

def raise_error(*args):
raise ProvisionedThroughputExceededException(None, None)

self.proto.ap_settings.router.update_message_month = MockAssist([
raise_error,
Mock(),
])

with patch.object(datetime, 'date',
Mock(wraps=datetime.date)) as patched:
patched.today.return_value = target_day
self._connect()
self._send_message(dict(messageType="hello",
uaid=orig_uaid,
channelIDs=[],
use_webpush=True))

d = Deferred()

def check_rotation(time_spent):
if time_spent > 3: # pragma: nocover
d.errback(Exception("Failed to rotate message table"))

if self.proto.ps.rotate_message_table: # pragma: nocover
reactor.callLater(0.2, check_rotation, 0.2 + time_spent)
return

eq_(self.proto.ps.rotate_message_table, False)
patch_range.stop()
d.callback(True)

def check_result(msg):
# it's fine you've not connected in a while, but
# you should recycle your endpoints since they're probably
# invalid by now anyway.
eq_(msg["status"], 200)
eq_(msg["uaid"], orig_uaid)

return self._check_response(check_result)
# Wait to see that the message table gets rotated
reactor.callLater(0.2, check_rotation, 0.2)

self._check_response(check_result)
return d

def test_hello(self):
self._connect()
Expand Down Expand Up @@ -713,7 +816,7 @@ def check_result(msg):

return self._check_response(check_result)

def test_hello_provisioned_exception(self):
def test_hello_provisioned_during_check(self):
self._connect()
self.proto.randrange = Mock(return_value=0.1)
# Fail out the register_user call
Expand Down Expand Up @@ -1183,6 +1286,33 @@ def test_register_kill_others_fail(self):
d.errback(ConnectError())
return d

def test_register_over_provisioning(self):
self._connect()
self.proto.ps.use_webpush = True
chid = str(uuid.uuid4())
self.proto.ps.uaid = uuid.uuid4().hex
self.proto.ap_settings.message.register_channel = register = Mock()

def throw_provisioned(*args, **kwargs):
raise ProvisionedThroughputExceededException(None, None)

register.side_effect = throw_provisioned

d = Deferred()

def check_register_result(_):
ok_(self.proto.ap_settings.message.register_channel.called)
ok_(self.send_mock.called)
args, _ = self.send_mock.call_args
msg = json.loads(args[0])
eq_(msg["messageType"], "error")
eq_(msg["reason"], "overloaded")
d.callback(True)

res = self.proto.process_register(dict(channelID=chid))
res.addCallback(check_register_result)
return d

def test_check_kill_self(self):
self._connect()
mock_agent = Mock()
Expand Down Expand Up @@ -1614,48 +1744,63 @@ def wait(result):
self.proto.ps._notification_fetch.addErrback(lambda x: d.errback(x))
return d

def test_process_notification_error(self):
def test_process_notifications_overload(self):
twisted.internet.base.DelayedCall.debug = True
self._connect()
self.proto.ps.uaid = uuid.uuid4().hex

def throw_error(*args, **kwargs):
raise Exception("An error happened!")
def throw_error(*args):
raise ProvisionedThroughputExceededException(None, None)

# Swap out fetch_notifications
self.proto.ap_settings.storage.fetch_notifications = MockAssist([
throw_error,
[],
])

# Start the randrange patch
patch_randrange = patch("autopush.websocket.randrange")
mock_randrange = patch_randrange.start()
mock_randrange.return_value = 0.1

# No-op the deferToLater
self.proto.deferToLater = Mock()

self.proto.ap_settings.storage = Mock(
**{"fetch_notifications.side_effect": throw_error})
self.proto.ps._check_notifications = True
self.proto.process_notifications()

# Tag on our own to follow up
d = Deferred()

def check_error(result):
eq_(self.proto.ps._check_notifications, False)
ok_(self.proto.log.failure.called)
def wait(result):
ok_(self.proto.deferToLater.called)
ok_(mock_randrange.called)
patch_randrange.stop()
d.callback(True)

self.proto.ps._notification_fetch.addBoth(check_error)
self.proto.ps._notification_fetch.addCallback(wait)
self.proto.ps._notification_fetch.addErrback(lambda x: d.errback(x))
return d

def test_process_notification_provisioned_error(self):
def test_process_notification_error(self):
self._connect()
self.proto.randrange = Mock(return_value=0.1)
self.proto.ps.uaid = uuid.uuid4().hex

def throw_error(*args, **kwargs):
raise ProvisionedThroughputExceededException(None, None)
raise Exception("An error happened!")

self.proto.ap_settings.storage = Mock(
**{"fetch_notifications.side_effect": throw_error})
self.proto.ps._check_notifications = True
self.proto.process_notifications()

def check_result(msg):
d = Deferred()

def check_error(result):
eq_(self.proto.ps._check_notifications, False)
eq_(msg["status"], 503)
eq_(msg["reason"], "error - overloaded")
self.flushLoggedErrors()
ok_(self.proto.log.failure.called)
d.callback(True)

return self._check_response(check_result)
self.proto.ps._notification_fetch.addBoth(check_error)
return d

def test_process_notif_doesnt_run_with_webpush_outstanding(self):
self._connect()
Expand Down
Loading

0 comments on commit ece6608

Please sign in to comment.