diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 6b7444eb..0c657760 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -6,7 +6,7 @@ from collections import defaultdict import twisted.internet.base -from autopush.tests.test_db import make_webpush_notification +from autobahn.twisted.util import sleep from boto.dynamodb2.exceptions import ( ProvisionedThroughputExceededException, ItemNotFound @@ -17,7 +17,11 @@ from nose.tools import assert_raises, eq_, ok_ from txstatsd.metrics.metrics import Metrics from twisted.internet import reactor -from twisted.internet.defer import Deferred +from twisted.internet.defer import ( + inlineCallbacks, + returnValue, + Deferred +) from twisted.internet.error import ConnectError from twisted.trial import unittest @@ -26,6 +30,7 @@ from autopush.settings import AutopushSettings from autopush.tests import MockAssist from autopush.utils import WebPushNotification +from autopush.tests.test_db import make_webpush_notification from autopush.websocket import ( PushState, PushServerFactory, @@ -136,36 +141,36 @@ def _connect(self): def _send_message(self, msg): self.proto.onMessage(json.dumps(msg).encode('utf8'), False) - def _wait_for_message(self, d, count=0.0): - args = self.send_mock.call_args_list - if len(args) < 1: - if count > 5.0: # pragma: nocover - try: - raise Exception("Timeout waiting for a message to send") - except: - d.errback() - reactor.callLater(0.1, self._wait_for_message, d, count+0.1) - return - - args = self.send_mock.call_args_list.pop(0) - return d.callback(args) - - def _wait_for_close(self, d): # pragma: nocover - if self.close_mock.call_args is not None: - d.callback(self.close_mock.call_args) - return - - reactor.callLater(0.1, self._wait_for_close, d) - - def _check_response(self, func): - """Waits for a message to be sent, and runs the func with it""" - def handle_message(result): - args, _ = result - func(json.loads(args[0])) - d = Deferred() - d.addCallback(handle_message) - self._wait_for_message(d) - return d + @inlineCallbacks + def _wait_for_close(self): + """Wait for a sendClose call""" + result = yield self._wait_for(lambda: self.close_mock.call_args) + returnValue(result) + + @inlineCallbacks + def _wait_for(self, predicate, duration=5, delay=0.1): + """Wait for predicate() to succeed, returning its result""" + start = time.time() + while (time.time() - start) < duration: + result = predicate() + if result: + returnValue(result) + yield sleep(delay) + else: # pragma: nocover + raise Exception("Timeout waiting for a message to send") + + @inlineCallbacks + def get_response(self): + """Wait for a JSON message to be received. + + Returns the message as a dict. + + """ + calls = self.send_mock.call_args_list + yield self._wait_for(lambda: len(calls)) + args = calls.pop(0) + msg = args[0][0] + returnValue(json.loads(msg)) def test_exc_catcher(self): req = Mock() @@ -318,52 +323,38 @@ def check_subbed(s): self.proto.processHandshake() self.proto.log_failure.assert_called() + @inlineCallbacks def test_binary_msg(self): self.proto.onMessage(b"asdfasdf", True) - d = Deferred() - d.addCallback(lambda x: True) - self._wait_for_close(d) - return d + yield self._wait_for_close() + @inlineCallbacks def test_not_dict(self): self.proto.onMessage("[]", False) - d = Deferred() - d.addCallback(lambda x: True) - self._wait_for_close(d) - return d + yield self._wait_for_close() + @inlineCallbacks def test_bad_json(self): self.proto.onMessage("}{{bad_json!!", False) - d = Deferred() - d.addCallback(lambda x: True) - self._wait_for_close(d) - return d + yield self._wait_for_close() + @inlineCallbacks def test_no_messagetype_after_hello(self): self._connect() self.proto.ps.uaid = dummy_uaid_str self._send_message(dict(data="wassup")) + close_args = yield self._wait_for_close() + _, kwargs = close_args + eq_(len(kwargs), 0) - def check_result(close_args): - _, kwargs = close_args - eq_(len(kwargs), 0) - d = Deferred() - d.addCallback(check_result) - self._wait_for_close(d) - return d - + @inlineCallbacks def test_unknown_messagetype(self): self._connect() self.proto.ps.uaid = dummy_uaid_str self._send_message(dict(messageType="wassup")) - - def check_result(close_args): - _, kwargs = close_args - eq_(len(kwargs), 0) - d = Deferred() - d.addCallback(check_result) - self._wait_for_close(d) - return d + close_args = yield self._wait_for_close() + _, kwargs = close_args + eq_(len(kwargs), 0) def test_close_with_cleanup(self): self._connect() @@ -379,6 +370,7 @@ def test_close_with_cleanup(self): name, _, _ = notif_mock.mock_calls[0] eq_(name, "cancel") + @inlineCallbacks def test_close_with_delivery_cleanup(self): self._connect() self.proto.ps.uaid = dummy_uaid_str @@ -396,19 +388,10 @@ def test_close_with_delivery_cleanup(self): # Close the connection self.proto.onClose(True, None, None) + yield self._wait_for(lambda: mock_agent.mock_calls) + self.flushLoggedErrors() - d = Deferred() - - def wait_for_agent_call(): # pragma: nocover - if not mock_agent.mock_calls: - reactor.callLater(0.1, wait_for_agent_call) - return - - self.flushLoggedErrors() - d.callback(True) - reactor.callLater(0.1, wait_for_agent_call) - return d - + @inlineCallbacks def test_close_with_delivery_cleanup_using_webpush(self): self._connect() self.proto.ps.uaid = dummy_uaid.hex @@ -426,19 +409,10 @@ def test_close_with_delivery_cleanup_using_webpush(self): # Close the connection self.proto.onClose(True, None, None) + yield self._wait_for(lambda: mock_agent.mock_calls) + self.flushLoggedErrors() - d = Deferred() - - def wait_for_agent_call(): # pragma: nocover - if not mock_agent.mock_calls: - reactor.callLater(0.1, wait_for_agent_call) - return - - self.flushLoggedErrors() - d.callback(True) - reactor.callLater(0.1, wait_for_agent_call) - return d - + @inlineCallbacks def test_close_with_delivery_cleanup_and_get_no_result(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex @@ -456,19 +430,12 @@ def test_close_with_delivery_cleanup_and_get_no_result(self): # Close the connection self.proto.onClose(True, None, None) + yield self._wait_for(lambda: len(mock_metrics.mock_calls) > 2) + eq_(len(mock_metrics.mock_calls), 3) + mock_metrics.increment.assert_called_with( + "client.notify_uaid_failure", tags=None) - d = Deferred() - - def wait_for_agent_call(): # pragma: nocover - if not mock_metrics.mock_calls: - reactor.callLater(0.1, wait_for_agent_call) - - mock_metrics.increment.assert_called_with( - "client.notify_uaid_failure", tags=None) - d.callback(True) - reactor.callLater(0.1, wait_for_agent_call) - return d - + @inlineCallbacks def test_close_with_delivery_cleanup_and_get_uaid_error(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex @@ -490,19 +457,12 @@ def raise_item(*args, **kwargs): # Close the connection self.proto.onClose(True, None, None) + yield self._wait_for(lambda: len(mock_metrics.mock_calls) > 2) + eq_(len(mock_metrics.mock_calls), 3) + mock_metrics.increment.assert_called_with( + "client.lookup_uaid_failure", tags=None) - d = Deferred() - - def wait_for_agent_call(): # pragma: nocover - if not mock_metrics.mock_calls: - reactor.callLater(0.1, wait_for_agent_call) - - mock_metrics.increment.assert_called_with( - "client.lookup_uaid_failure", tags=None) - d.callback(True) - reactor.callLater(0.1, wait_for_agent_call) - return d - + @inlineCallbacks def test_close_with_delivery_cleanup_and_no_node_id(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex @@ -520,17 +480,9 @@ def test_close_with_delivery_cleanup_and_no_node_id(self): # Close the connection self.proto.onClose(True, None, None) + yield self._wait_for(lambda: mock_node_get.mock_calls) - d = Deferred() - - def wait_for_agent_call(): # pragma: nocover - if not mock_node_get.mock_calls: - reactor.callLater(0.1, wait_for_agent_call) - - d.callback(True) - reactor.callLater(0.1, wait_for_agent_call) - return d - + @inlineCallbacks def test_hello_old(self): orig_uaid = "deadbeef00000000abad1dea00000000" @@ -583,16 +535,15 @@ def fake_msg(data): channelIDs=[], use_webpush=True)) - def check_result(msg): - eq_(self.proto.ps.rotate_message_table, False) - # it's fine you've not connected in a while, but - # you should recycle your endpoints since they're probably - # invalid by now anyway. - eq_(msg["status"], 200) - ok_(msg["uaid"] != orig_uaid) - - return self._check_response(check_result) + msg = yield self.get_response() + eq_(self.proto.ps.rotate_message_table, False) + # it's fine you've not connected in a while, but you should + # recycle your endpoints since they're probably invalid by now + # anyway. + eq_(msg["status"], 200) + ok_(msg["uaid"] != orig_uaid) + @inlineCallbacks def test_hello_tomorrow(self): orig_uaid = "deadbeef00000000abad1dea00000000" router = self.proto.ap_settings.router @@ -640,33 +591,18 @@ def fake_msg(data): uaid=orig_uaid, channelIDs=[], use_webpush=True)) - - d = Deferred() - - def check_rotation(time_spent): - if time_spent > 3: # pragma: nocover - d.errback(Exception("Failed to rotate message table")) - - if self.proto.ps.rotate_message_table: # pragma: nocover - reactor.callLater(0.2, check_rotation, 0.2 + time_spent) - return - - eq_(self.proto.ps.rotate_message_table, False) - d.callback(True) - - def check_result(msg): - # it's fine you've not connected in a while, but - # you should recycle your endpoints since they're probably - # invalid by now anyway. - eq_(msg["status"], 200) - eq_(msg["uaid"], orig_uaid) - - # Wait to see that the message table gets rotated - reactor.callLater(0.2, check_rotation, 0.2) - - self._check_response(check_result) - return d - + msg = yield self.get_response() + # it's fine you've not connected in a while, but you should + # recycle your endpoints since they're probably invalid by now + # anyway. + eq_(msg["status"], 200) + eq_(msg["uaid"], orig_uaid) + + # Wait to see that the message table gets rotated + yield self._wait_for(lambda: not self.proto.ps.rotate_message_table) + eq_(self.proto.ps.rotate_message_table, False) + + @inlineCallbacks def test_hello_tomorrow_provision_error(self): orig_uaid = "deadbeef00000000abad1dea00000000" router = self.proto.ap_settings.router @@ -726,22 +662,8 @@ def raise_error(*args): uaid=orig_uaid, channelIDs=[], use_webpush=True)) - - d = Deferred() - d.addBoth(lambda x: patch_range.stop()) - - def check_rotation(time_spent): - if time_spent > 3: # pragma: nocover - d.errback(Exception("Failed to rotate message table")) - - if self.proto.ps.rotate_message_table: # pragma: nocover - reactor.callLater(0.2, check_rotation, 0.2 + time_spent) - return - - eq_(self.proto.ps.rotate_message_table, False) - d.callback(True) - - def check_result(msg): + try: + msg = yield self.get_response() # it's fine you've not connected in a while, but # you should recycle your endpoints since they're probably # invalid by now anyway. @@ -749,57 +671,47 @@ def check_result(msg): eq_(msg["uaid"], orig_uaid) # Wait to see that the message table gets rotated - reactor.callLater(0.2, check_rotation, 0.2) - - self._check_response(check_result) - return d + yield self._wait_for( + lambda: not self.proto.ps.rotate_message_table + ) + eq_(self.proto.ps.rotate_message_table, False) + finally: + patch_range.stop() + @inlineCallbacks def test_hello(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) - def check_result(msg): - eq_(msg["status"], 200) - return self._check_response(check_result) + msg = yield self.get_response() + eq_(msg["status"], 200) + @inlineCallbacks def test_hello_webpush_uses_one_db_call(self): db.TRACK_DB_CALLS = True db.DB_CALLS = [] self._connect() self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) + msg = yield self.get_response() + yield self._wait_for(lambda: len(db.DB_CALLS) > 2, duration=3) + eq_(db.DB_CALLS, + ['register_user', 'fetch_messages', 'fetch_timestamp_messages']) + eq_(msg["status"], 200) + db.DB_CALLS = [] + db.TRACK_DB_CALLS = False - d = Deferred() - - def check_result(msg, duration=0): - if len(db.DB_CALLS) < 3: # pragma: nocover - if duration > 3.0: # pragma: nocover - raise Exception("db calls isn't 3 yet") - else: - reactor.callLater(0.1, check_result, msg, duration+0.1) - return - - eq_(db.DB_CALLS, ['register_user', 'fetch_messages', - 'fetch_timestamp_messages']) - eq_(msg["status"], 200) - db.DB_CALLS = [] - db.TRACK_DB_CALLS = False - d.callback(True) - f = self._check_response(check_result) - f.addErrback(lambda x: d.callback(True)) - return d - + @inlineCallbacks def test_hello_with_webpush(self): self._connect() self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) - - def check_result(msg): - eq_(msg["status"], 200) - ok_("use_webpush" in msg) eq_(self.proto.base_tags, ['use_webpush:True']) - return self._check_response(check_result) + msg = yield self.get_response() + eq_(msg["status"], 200) + ok_("use_webpush" in msg) + @inlineCallbacks def test_hello_with_missing_router_type(self): self._connect() uaid = uuid.uuid4().hex @@ -811,12 +723,11 @@ def test_hello_with_missing_router_type(self): self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + ok_(msg["uaid"] != uaid) - def check_result(msg): - eq_(msg["status"], 200) - ok_(msg["uaid"] != uaid) - return self._check_response(check_result) - + @inlineCallbacks def test_hello_with_missing_current_month(self): self._connect() uaid = uuid.uuid4().hex @@ -828,12 +739,11 @@ def test_hello_with_missing_current_month(self): )) self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid, use_webpush=True)) + msg = yield self.get_response() + eq_(msg["status"], 200) + ok_(msg["uaid"] != uaid) - def check_result(msg): - eq_(msg["status"], 200) - ok_(msg["uaid"] != uaid) - return self._check_response(check_result) - + @inlineCallbacks def test_hello_with_uaid(self): self._connect() uaid = uuid.uuid4().hex @@ -845,12 +755,11 @@ def test_hello_with_uaid(self): )) self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + eq_(msg["uaid"], uaid) - def check_result(msg): - eq_(msg["status"], 200) - eq_(msg["uaid"], uaid) - return self._check_response(check_result) - + @inlineCallbacks def test_hello_resets_record(self): self._connect() uaid = uuid.uuid4().hex @@ -862,46 +771,42 @@ def test_hello_resets_record(self): )) self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + eq_(msg["uaid"], uaid) + eq_(self.proto.ps.reset_uaid, True) - def check_result(msg): - eq_(msg["status"], 200) - eq_(msg["uaid"], uaid) - eq_(self.proto.ps.reset_uaid, True) - return self._check_response(check_result) - + @inlineCallbacks def test_hello_with_bad_uaid(self): self._connect() uaid = "ajsidlfjlsdjflasjjailsdf" self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + ok_(msg["uaid"] != uaid) - def check_result(msg): - eq_(msg["status"], 200) - ok_(msg["uaid"] != uaid) - return self._check_response(check_result) - + @inlineCallbacks def test_hello_with_bad_uaid_dash(self): self._connect() uaid = str(uuid.uuid4()) self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + ok_(msg["uaid"] != uaid) - def check_result(msg): - eq_(msg["status"], 200) - ok_(msg["uaid"] != uaid) - return self._check_response(check_result) - + @inlineCallbacks def test_hello_with_bad_uaid_case(self): self._connect() uaid = uuid.uuid4().hex.upper() self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + ok_(msg["uaid"] != uaid) - def check_result(msg): - eq_(msg["status"], 200) - ok_(msg["uaid"] != uaid) - return self._check_response(check_result) - + @inlineCallbacks def test_hello_failure(self): self._connect() # Fail out the register_user call @@ -909,14 +814,12 @@ def test_hello_failure(self): router.table.connection.update_item = Mock(side_effect=KeyError) self._send_message(dict(messageType="hello", channelIDs=[], stop=1)) + msg = yield self.get_response() + eq_(msg["status"], 503) + eq_(msg["reason"], "error") + self.flushLoggedErrors() - def check_result(msg): - eq_(msg["status"], 503) - eq_(msg["reason"], "error") - self.flushLoggedErrors() - - return self._check_response(check_result) - + @inlineCallbacks def test_hello_provisioned_during_check(self): self._connect() self.proto.randrange = Mock(return_value=0.1) @@ -929,14 +832,12 @@ def throw_error(*args, **kwargs): router.table.connection.update_item = Mock(side_effect=throw_error) self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + eq_(msg["status"], 503) + eq_(msg["reason"], "error - overloaded") + self.flushLoggedErrors() - def check_result(msg): - eq_(msg["status"], 503) - eq_(msg["reason"], "error - overloaded") - self.flushLoggedErrors() - - return self._check_response(check_result) - + @inlineCallbacks def test_hello_jsonresponseerror(self): self._connect() @@ -950,14 +851,12 @@ def throw_error(*args, **kwargs): router.table.connection.update_item = Mock(side_effect=throw_error) self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + eq_(msg["status"], 503) + eq_(msg["reason"], "error - overloaded") + self.flushLoggedErrors() - def check_result(msg): - eq_(msg["status"], 503) - eq_(msg["reason"], "error - overloaded") - self.flushLoggedErrors() - - return self._check_response(check_result) - + @inlineCallbacks def test_hello_check_fail(self): self._connect() @@ -966,49 +865,35 @@ def test_hello_check_fail(self): Mock(return_value=(False, {})) self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + calls = self.proto.ap_settings.router.register_user.mock_calls + eq_(len(calls), 1) + eq_(msg["status"], 500) + eq_(msg["reason"], "already_connected") - def check_result(msg): - calls = self.proto.ap_settings.router.register_user.mock_calls - eq_(len(calls), 1) - eq_(msg["status"], 500) - eq_(msg["reason"], "already_connected") - return self._check_response(check_result) - + @inlineCallbacks def test_hello_dupe(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + eq_(msg["status"], 200) - d = Deferred() - d.addCallback(lambda x: True) - - def check_second_hello(msg): - eq_(msg["status"], 401) - d.callback(True) - - def check_first_hello(msg): - eq_(msg["status"], 200) - # Send another hello - self._send_message(dict(messageType="hello", channelIDs=[])) - self._check_response(check_second_hello) - f = self._check_response(check_first_hello) - f.addErrback(lambda x: d.errback(x)) - return d + # Send another hello + self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + eq_(msg["status"], 401) + @inlineCallbacks def test_hello_timeout(self): connected = time.time() self.proto.ap_settings.hello_timeout = 3 self._connect() + close_args = yield self._wait_for_close() + _, kwargs = close_args + eq_(len(kwargs), 0) + ok_(time.time() - connected >= 3) - def check_elapsed(close_args): - _, kwargs = close_args - eq_(len(kwargs), 0) - ok_(time.time() - connected >= 3) - - d = Deferred() - d.addCallback(check_elapsed) - self._wait_for_close(d) - return d - + @inlineCallbacks def test_hello_timeout_with_wake_timeout(self): self.proto.ap_settings.hello_timeout = 3 self.proto.ap_settings.wake_timeout = 3 @@ -1021,17 +906,12 @@ def test_hello_timeout_with_wake_timeout(self): "mnc": "banana", "netid": "gorp", "ignored": "ok"})) + close_args = yield self._wait_for_close() + ok_(ms_time() - self.proto.ps.connected_at >= 3000) + _, kwargs = close_args + eq_(kwargs, {"code": 4774, "reason": "UDP Idle"}) - def check_elapsed(close_args): - ok_(ms_time() - self.proto.ps.connected_at >= 3000) - _, kwargs = close_args - eq_(kwargs, {"code": 4774, "reason": "UDP Idle"}) - - d = Deferred() - d.addCallback(check_elapsed) - self._wait_for_close(d) - return d - + @inlineCallbacks def test_hello_udp(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[], @@ -1041,16 +921,15 @@ def test_hello_udp(self): "mnc": "banana", "netid": "gorp", "ignored": "ok"})) - - def check_result(msg): - eq_(msg["status"], 200) - route_data = self.proto.ap_settings.router.get_uaid( - msg["uaid"]).get('wake_data') - eq_(route_data, { - 'data': {"ip": "127.0.0.1", "port": 9999, "mcc": "hammer", - "mnc": "banana", "netid": "gorp"}}) - return self._check_response(check_result) - + msg = yield self.get_response() + eq_(msg["status"], 200) + route_data = self.proto.ap_settings.router.get_uaid( + msg["uaid"]).get('wake_data') + eq_(route_data, + {'data': {"ip": "127.0.0.1", "port": 9999, "mcc": "hammer", + "mnc": "banana", "netid": "gorp"}}) + + @inlineCallbacks def test_bad_hello_udp(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[], @@ -1059,70 +938,46 @@ def test_bad_hello_udp(self): "mnc": "banana", "netid": "gorp", "ignored": "ok"})) + msg = yield self.get_response() + eq_(msg["status"], 200) + ok_("wake_data" not in + self.proto.ap_settings.router.get_uaid(msg["uaid"]).keys()) - def check_result(msg): - eq_(msg["status"], 200) - ok_("wake_data" not in - self.proto.ap_settings.router.get_uaid(msg["uaid"]).keys()) - return self._check_response(check_result) - + @inlineCallbacks def test_not_hello(self): self._connect() self._send_message(dict(messageType="wooooo")) + close_args = yield self._wait_for_close() + _, kwargs = close_args + eq_(len(kwargs), 0) - def check_result(close_args): - _, kwargs = close_args - eq_(len(kwargs), 0) - d = Deferred() - d.addCallback(check_result) - self._wait_for_close(d) - return d - + @inlineCallbacks def test_hello_env(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + eq_(msg["env"], "test") - def check_result(msg): - eq_(msg["env"], "test") - return self._check_response(check_result) - + @inlineCallbacks def test_ping(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + eq_(msg["status"], 200) + self._send_message({}) + msg = yield self.get_response() + eq_(msg, {}) - d = Deferred() - - def check_ping_result(msg): - eq_(msg, {}) - d.callback(True) - - def check_result(msg): - eq_(msg["status"], 200) - self._send_message({}) - g = self._check_response(check_ping_result) - g.addErrback(lambda x: d.errback(x)) - - f = self._check_response(check_result) - f.addErrback(lambda x: d.errback(x)) - return d - + @inlineCallbacks def test_ping_too_much(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) - - d = Deferred() - - def check_result(msg): - eq_(msg["status"], 200) - self.proto.ps.last_ping = time.time() - 30 - self.proto.sendClose = Mock() - self._send_message({}) - ok_(self.proto.sendClose.called) - d.callback(True) - - f = self._check_response(check_result) - f.addErrback(lambda x: d.errback(x)) - return d + msg = yield self.get_response() + eq_(msg["status"], 200) + self.proto.ps.last_ping = time.time() - 30 + self.proto.sendClose = Mock() + self._send_message({}) + ok_(self.proto.sendClose.called) def test_auto_ping(self): self.proto.ps.ping_time_out = False @@ -1192,30 +1047,22 @@ def check_result(result): ok_(d is not None) return d + @inlineCallbacks def test_register(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[], stop=1)) + msg = yield self.get_response() + ok_("messageType" in msg) - d = Deferred() - d.addCallback(lambda x: True) - - def check_register_result(msg): - eq_(msg["status"], 200) - eq_(msg["messageType"], "register") - ok_("pushEndpoint" in msg) - assert_called_included(self.proto.log.info, format="Register") - d.callback(True) - - def check_hello_result(msg): - ok_("messageType" in msg) - self._send_message(dict(messageType="register", - channelID=str(uuid.uuid4()))) - self._check_response(check_register_result) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + self._send_message(dict(messageType="register", + channelID=str(uuid.uuid4()))) + msg = yield self.get_response() + eq_(msg["status"], 200) + eq_(msg["messageType"], "register") + ok_("pushEndpoint" in msg) + assert_called_included(self.proto.log.info, format="Register") + @inlineCallbacks def test_register_webpush(self): self._connect() self.proto.ps.use_webpush = True @@ -1223,17 +1070,11 @@ def test_register_webpush(self): self.proto.ps.uaid = uuid.uuid4().hex self.proto.ap_settings.message.register_channel = Mock() - d = Deferred() - - def check_register_result(msg): - ok_(self.proto.ap_settings.message.register_channel.called) - assert_called_included(self.proto.log.info, format="Register") - d.callback(True) - - res = self.proto.process_register(dict(channelID=chid)) - res.addCallback(check_register_result) - return d + yield self.proto.process_register(dict(channelID=chid)) + ok_(self.proto.ap_settings.message.register_channel.called) + assert_called_included(self.proto.log.info, format="Register") + @inlineCallbacks def test_register_webpush_with_key(self): self._connect() self.proto.ps.use_webpush = True @@ -1253,108 +1094,66 @@ def echo(string): self.proto.ap_settings.fernet.encrypt = echo - d = Deferred() - - def check_register_result(msg, endpoint): - eq_(endpoint, - self.proto.sendJSON.call_args[0][0]['pushEndpoint']) - ok_(self.proto.ap_settings.message.register_channel.called) - assert_called_included(self.proto.log.info, format="Register") - d.callback(True) - - res = self.proto.process_register( + yield self.proto.process_register( dict(channelID=chid, - key=base64url_encode(test_key))) - res.addCallback(check_register_result, test_endpoint) - return d + key=base64url_encode(test_key)) + ) + eq_(test_endpoint, + self.proto.sendJSON.call_args[0][0]['pushEndpoint']) + ok_(self.proto.ap_settings.message.register_channel.called) + assert_called_included(self.proto.log.info, format="Register") + @inlineCallbacks def test_register_no_chid(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + ok_("messageType" in msg) - d = Deferred() - d.addCallback(lambda x: True) - - def check_register_result(msg): - eq_(msg["status"], 401) - eq_(msg["messageType"], "register") - d.callback(True) - - def check_hello_result(msg): - ok_("messageType" in msg) - self._send_message(dict(messageType="register")) - self._check_response(check_register_result) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + self._send_message(dict(messageType="register")) + msg = yield self.get_response() + eq_(msg["status"], 401) + eq_(msg["messageType"], "register") + @inlineCallbacks def test_register_bad_chid(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + ok_("messageType" in msg) - d = Deferred() - d.addCallback(lambda x: True) - - def check_register_result(msg): - eq_(msg["status"], 401) - eq_(msg["messageType"], "register") - d.callback(True) - - def check_hello_result(msg): - ok_("messageType" in msg) - self._send_message(dict(messageType="register", channelID="oof")) - self._check_response(check_register_result) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + self._send_message(dict(messageType="register", channelID="oof")) + msg = yield self.get_response() + eq_(msg["status"], 401) + eq_(msg["messageType"], "register") + @inlineCallbacks def test_register_bad_chid_upper(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + ok_("messageType" in msg) - d = Deferred() - d.addCallback(lambda x: True) - - def check_register_result(msg): - eq_(msg["status"], 401) - eq_(msg["messageType"], "register") - d.callback(True) - - def check_hello_result(msg): - ok_("messageType" in msg) - self._send_message(dict(messageType="register", - channelID=str(uuid.uuid4()).upper())) - self._check_response(check_register_result) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + self._send_message(dict(messageType="register", + channelID=str(uuid.uuid4()).upper())) + msg = yield self.get_response() + eq_(msg["status"], 401) + eq_(msg["messageType"], "register") + @inlineCallbacks def test_register_bad_chid_nodash(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + ok_("messageType" in msg) - d = Deferred() - d.addCallback(lambda x: True) - - def check_register_result(msg): - eq_(msg["status"], 401) - eq_(msg["messageType"], "register") - d.callback(True) - - def check_hello_result(msg): - ok_("messageType" in msg) - self._send_message( - dict(messageType="register", - channelID=str(uuid.uuid4()).replace('-', ''))) - self._check_response(check_register_result) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + self._send_message(dict(messageType="register", + channelID=str(uuid.uuid4()).replace('-', ''))) + msg = yield self.get_response() + eq_(msg["status"], 401) + eq_(msg["messageType"], "register") + @inlineCallbacks def test_register_bad_crypto(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex @@ -1366,19 +1165,10 @@ def throw_error(*args, **kwargs): **{"encrypt.side_effect": throw_error}) self._send_message(dict(messageType="register", channelID=str(uuid.uuid4()))) - - d = Deferred() - d.addCallback(lambda x: True) - - def check_register_result(msg): - eq_(msg["status"], 500) - eq_(msg["messageType"], "register") - self.proto.log.failure.assert_called() - d.callback(True) - - f = self._check_response(check_register_result) - f.addErrback(lambda x: d.errback(x)) - return d + msg = yield self.get_response() + eq_(msg["status"], 500) + eq_(msg["messageType"], "register") + self.proto.log.failure.assert_called() def test_register_kill_others(self): self._connect() @@ -1408,6 +1198,7 @@ def test_register_kill_others_fail(self): d.errback(ConnectError()) return d + @inlineCallbacks def test_register_over_provisioning(self): self._connect() self.proto.ps.use_webpush = True @@ -1420,20 +1211,13 @@ def throw_provisioned(*args, **kwargs): register.side_effect = throw_provisioned - d = Deferred() - - def check_register_result(_): - ok_(self.proto.ap_settings.message.register_channel.called) - ok_(self.send_mock.called) - args, _ = self.send_mock.call_args - msg = json.loads(args[0]) - eq_(msg["messageType"], "error") - eq_(msg["reason"], "overloaded") - d.callback(True) - - res = self.proto.process_register(dict(channelID=chid)) - res.addCallback(check_register_result) - return d + yield self.proto.process_register(dict(channelID=chid)) + ok_(self.proto.ap_settings.message.register_channel.called) + ok_(self.send_mock.called) + args, _ = self.send_mock.call_args + msg = json.loads(args[0]) + eq_(msg["messageType"], "error") + eq_(msg["reason"], "overloaded") def test_check_kill_self(self): self._connect() @@ -1484,74 +1268,49 @@ def test_unregister_with_webpush(self): self.proto.process_unregister(dict(channelID=chid)) ok_(self.proto.force_retry.called) + @inlineCallbacks def test_ws_unregister(self): + chid = str(uuid.uuid4()) self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) + msg = yield self.get_response() + eq_(msg["messageType"], "hello") + eq_(msg["status"], 200) - d = Deferred() - d.addCallback(lambda x: True) - chid = str(uuid.uuid4()) - - def check_unregister_result(msg): - eq_(msg["status"], 200) - eq_(msg["channelID"], chid) - eq_(len(self.proto.log.mock_calls), 2) - assert_called_included(self.proto.log.info, format="Unregister") - d.callback(True) - - def check_hello_result(msg): - eq_(msg["messageType"], "hello") - eq_(msg["status"], 200) - self._send_message(dict(messageType="unregister", - code=104, - channelID=chid)) - self._check_response(check_unregister_result) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + self._send_message(dict(messageType="unregister", + code=104, + channelID=chid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + eq_(msg["channelID"], chid) + eq_(len(self.proto.log.mock_calls), 2) + assert_called_included(self.proto.log.info, format="Unregister") + @inlineCallbacks def test_ws_unregister_without_chid(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex self._send_message(dict(messageType="unregister")) + msg = yield self.get_response() + eq_(msg["status"], 401) + eq_(msg["messageType"], "unregister") - d = Deferred() - d.addCallback(lambda x: True) - - def check_unregister_result(msg): - eq_(msg["status"], 401) - eq_(msg["messageType"], "unregister") - d.callback(True) - - f = self._check_response(check_unregister_result) - f.addErrback(lambda x: d.errback(x)) - return d - + @inlineCallbacks def test_ws_unregister_bad_chid(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex self._send_message(dict(messageType="unregister", channelID="}{$@!asdf")) + msg = yield self.get_response() + eq_(msg["status"], 401) + eq_(msg["messageType"], "unregister") - d = Deferred() - - def check_unregister_result(msg): - eq_(msg["status"], 401) - eq_(msg["messageType"], "unregister") - d.callback(True) - - f = self._check_response(check_unregister_result) - f.addErrback(lambda x: d.errback(x)) - return d - + @inlineCallbacks def test_ws_unregister_fail(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex chid = str(uuid.uuid4()) - d = Deferred() - # Replace storage delete with call to fail table = self.proto.ap_settings.storage.table delete = table.delete_item @@ -1564,18 +1323,9 @@ def raise_exception(*args, **kwargs): table.delete_item = MockAssist([raise_exception, True]) self._send_message(dict(messageType="unregister", channelID=chid)) - - def wait_for_times(): # pragma: nocover - if not self.proto.log.failure.called: - reactor.callLater(0.1, wait_for_times) - else: - self.proto.log.failure.assert_called_once() - assert_called_included(self.proto.log.info, - format="Unregister") - d.callback(True) - - reactor.callLater(0.1, wait_for_times) - return d + yield self._wait_for(lambda: self.proto.log.failure.called) + self.proto.log.failure.assert_called_once() + assert_called_included(self.proto.log.info, format="Unregister") def test_notification(self): self._connect() @@ -1646,36 +1396,30 @@ def test_notification_avoid_newer_delivery(self): args = self.send_mock.call_args eq_(args, None) + @inlineCallbacks def test_ack(self): + chid = str(uuid.uuid4()) self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) - d = Deferred() - chid = str(uuid.uuid4()) - # stick a notification to ack in self.proto.ps.direct_updates[chid] = 12 - def check_hello_result(msg): - eq_(msg["status"], 200) + msg = yield self.get_response() + eq_(msg["status"], 200) - # Send our ack - self._send_message(dict(messageType="ack", - updates=[{"channelID": chid, - "version": 12}])) - - # Verify it was cleared out - eq_(len(self.proto.ps.direct_updates), 0) - eq_(len(self.proto.log.info.mock_calls), 2) - assert_called_included(self.proto.log.info, - format="Ack", - router_key="simplepush", - message_source="direct") - d.callback(True) + # Send our ack + self._send_message(dict(messageType="ack", + updates=[{"channelID": chid, + "version": 12}])) - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + # Verify it was cleared out + eq_(len(self.proto.ps.direct_updates), 0) + eq_(len(self.proto.log.info.mock_calls), 2) + assert_called_included(self.proto.log.info, + format="Ack", + router_key="simplepush", + message_source="direct") def test_ack_with_bad_input(self): self._connect() @@ -1765,6 +1509,7 @@ def test_ack_remove_missing(self): self.proto.ps.updates_sent[dummy_chid_str] = [] self.proto._handle_webpush_update_remove(None, dummy_chid_str, notif) + @inlineCallbacks def test_ack_fails_first_time(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex @@ -1793,21 +1538,11 @@ def __call__(self, *args, **kwargs): self.proto.process_notifications = Mock() self.proto.ps._check_notifications = True - d = Deferred() - - def wait_for_delete(): # pragma: nocover - calls = self.transport_mock.mock_calls - if len(calls) < 2: - reactor.callLater(0.1, wait_for_delete) - return - - eq_(self.proto.ps.updates_sent, {}) - process_calls = self.proto.process_notifications.mock_calls - eq_(len(process_calls), 1) - d.callback(True) - - reactor.callLater(0.1, wait_for_delete) - return d + yield sleep(0.1) + yield self._wait_for(lambda: len(self.transport_mock.mock_calls) > 1) + eq_(self.proto.ps.updates_sent, {}) + process_calls = self.proto.process_notifications.mock_calls + eq_(len(process_calls), 1) def test_ack_missing_updates(self): self._connect() @@ -1871,6 +1606,7 @@ def wait(result): self.proto.ps._notification_fetch.addErrback(lambda x: d.errback(x)) return d + @inlineCallbacks def test_process_notifications_overload(self): twisted.internet.base.DelayedCall.debug = True self._connect() @@ -1894,18 +1630,10 @@ def throw_error(*args): self.proto.deferToLater = Mock() self.proto.process_notifications() - - # Tag on our own to follow up - d = Deferred() - - def wait(result): - ok_(self.proto.deferToLater.called) - ok_(mock_randrange.called) - patch_randrange.stop() - d.callback(True) - self.proto.ps._notification_fetch.addCallback(wait) - self.proto.ps._notification_fetch.addErrback(lambda x: d.errback(x)) - return d + yield self.proto.ps._notification_fetch + ok_(self.proto.deferToLater.called) + ok_(mock_randrange.called) + patch_randrange.stop() def test_process_notification_error(self): self._connect() @@ -2046,6 +1774,7 @@ def check(*args, **kwargs): self.proto.ps._notification_fetch.addBoth(check) return d + @inlineCallbacks def test_notification_results(self): # Populate the database for ourself uaid = uuid.uuid4().hex @@ -2074,34 +1803,24 @@ def test_notification_results(self): self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) + eq_(msg["messageType"], "hello") - d = Deferred() - - def check_notifs(msg): - eq_(msg["messageType"], "notification") - eq_(len(msg["updates"]), 2) - for update in msg["updates"]: - uchid = update["channelID"] - ver = update["version"] - if uchid == chid: - eq_(ver, 12) - elif uchid == chid3: - eq_(ver, 9) - ok_(uchid in [chid, chid3]) - d.callback(True) - - def check_result(msg): - eq_(msg["status"], 200) - eq_(msg["messageType"], "hello") - - # Now wait for the notification results - nd = self._check_response(check_notifs) - nd.addErrback(lambda x: d.errback(x)) - - cd = self._check_response(check_result) - cd.addErrback(lambda x: d.errback(x)) - return d - + # Now wait for the notification results + msg = yield self.get_response() + eq_(msg["messageType"], "notification") + eq_(len(msg["updates"]), 2) + for update in msg["updates"]: + uchid = update["channelID"] + ver = update["version"] + if uchid == chid: + eq_(ver, 12) + elif uchid == chid3: + eq_(ver, 9) + ok_(uchid in [chid, chid3]) + + @inlineCallbacks def test_notification_dont_deliver_after_ack(self): self._connect() @@ -2124,47 +1843,32 @@ def test_notification_dont_deliver_after_ack(self): eq_(len(results), 1) self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + msg = yield self.get_response() + eq_(msg["status"], 200) - d = Deferred() - - def wait_for_clear(count=0.0): - if self.proto.ps.updates_sent: # pragma: nocover - if count > 5.0: - d.errback(Exception("Time-out waiting")) - reactor.callLater(0.1, wait_for_clear, count+0.1) - return - - # Accepting again - eq_(self.proto.ps.updates_sent, {}) - - # Check that storage is clear - notifs = storage.fetch_notifications(uaid) - eq_(len(notifs), 0) - d.callback(True) + # Now wait for the notification + msg = yield self.get_response() + eq_(msg["messageType"], "notification") + updates = msg["updates"] + eq_(len(updates), 1) + eq_(updates[0]["channelID"], chid) + eq_(updates[0]["version"], 10) - def check_notif_result(msg): - eq_(msg["messageType"], "notification") - updates = msg["updates"] - eq_(len(updates), 1) - eq_(updates[0]["channelID"], chid) - eq_(updates[0]["version"], 10) - # Send our ack - self._send_message(dict(messageType="ack", - updates=[{"channelID": chid, - "version": 10}])) - # Wait for updates to be cleared and notifications accepted again - reactor.callLater(0.1, wait_for_clear) - - def check_hello_result(msg): - eq_(msg["status"], 200) + # Send our ack + self._send_message(dict(messageType="ack", + updates=[{"channelID": chid, + "version": 10}])) + # Wait for updates to be cleared and notifications accepted again + yield sleep(0.1) + yield self._wait_for(lambda: not self.proto.ps.updates_sent) + # Accepting again + eq_(self.proto.ps.updates_sent, {}) - # Now wait for the notification - nd = self._check_response(check_notif_result) - nd.addErrback(lambda x: d.errback(x)) - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d + # Check that storage is clear + notifs = storage.fetch_notifications(uaid) + eq_(len(notifs), 0) + @inlineCallbacks def test_notification_dont_deliver(self): # Populate the database for ourself uaid = uuid.uuid4().hex @@ -2188,50 +1892,29 @@ def test_notification_dont_deliver(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) + yield self.proto.ps._register + # Setup updates_sent to avoid a notification send + self.proto.ps.updates_sent[chid] = 14 - d = Deferred() - - def check_mock_call(count=0.0): - calls = self.proto.process_notifications.mock_calls - if len(calls) < 1: - if count > 5.0: # pragma: nocover - raise Exception("Time-out waiting") - reactor.callLater(0.1, check_mock_call, count+0.1) - return + # Notification check has started, indicate to check + # notifications again + self.proto.ps._check_notifications = True - eq_(len(calls), 1) - d.callback(True) + # Now replace process_notifications so it won't be + # run again + self.proto.process_notifications = Mock() - def check_call(result): - send_calls = self.send_mock.mock_calls - # There should be one, for the hello response - # No notifications should've been delivered after - # this notifiation check - eq_(len(send_calls), 1) - - # Now we wait for the mock call to run - reactor.callLater(0.1, check_mock_call) - - # Run immediately after hello was processed - def after_hello(result): - # Setup updates_sent to avoid a notification send - self.proto.ps.updates_sent[chid] = 14 - - # Notification check has started, indicate to check - # notifications again - self.proto.ps._check_notifications = True - - # Now replace process_notifications so it won't be - # run again - self.proto.process_notifications = Mock() - - # Chain our check for the call - self.proto.ps._notification_fetch.addBoth(check_call) - self.proto.ps._notification_fetch.addErrback( - lambda x: d.errback(x)) - self.proto.ps._register.addCallback(after_hello) - self.proto.ps._register.addErrback(lambda x: d.errback(x)) - return d + # Chain our check for the call + yield self.proto.ps._notification_fetch + send_calls = self.send_mock.mock_calls + # There should be one, for the hello response No notifications + # should've been delivered after this notifiation check + eq_(len(send_calls), 1) + + # Now we wait for the mock call to run + calls = self.proto.process_notifications.mock_calls + yield self._wait_for(lambda: len(calls)) + eq_(len(calls), 1) def test_incomplete_uaid(self): mm = self.proto.ap_settings.router = Mock()