diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index e9efffe9..e35a7df3 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -47,6 +47,17 @@ def tearDown(): tearDown() +def assert_called_included(mock, **kwargs): + """Like assert_called_with but asserts a call was made including + the specified kwargs (but allowing additional args/kwargs)""" + mock.assert_called() + _, mock_kwargs = mock.call_args + for name, val in kwargs.iteritems(): + if name not in mock_kwargs or mock_kwargs[name] != val: + raise AssertionError("%s not called with keyword arg %s=%s" % + (mock, name, val)) + + class WebsocketTestCase(unittest.TestCase): def setUp(self): @@ -929,6 +940,7 @@ def check_register_result(msg): eq_(msg["status"], 200) eq_(msg["messageType"], "register") assert "pushEndpoint" in msg + assert_called_included(self.proto.log.info, format="Register") d.callback(True) def check_hello_result(msg): @@ -952,6 +964,7 @@ def test_register_webpush(self): def check_register_result(msg): assert self.proto.ap_settings.message.register_channel.called + assert_called_included(self.proto.log.info, format="Register") d.callback(True) res = self.proto.process_register(dict(channelID=chid)) @@ -983,6 +996,7 @@ def check_register_result(msg, test_endpoint): eq_(test_endpoint, self.proto.sendJSON.call_args[0][0]['pushEndpoint']) assert self.proto.ap_settings.message.register_channel.called + assert_called_included(self.proto.log.info, format="Register") d.callback(True) res = self.proto.process_register( @@ -1096,6 +1110,7 @@ def throw_error(*args, **kwargs): def check_register_result(msg): eq_(msg["status"], 500) eq_(msg["messageType"], "register") + self.proto.log.failure.assert_called() d.callback(True) f = self._check_response(check_register_result) @@ -1191,6 +1206,7 @@ def check_unregister_result(msg): eq_(msg["status"], 200) eq_(msg["channelID"], chid) eq_(len(self.proto.log.mock_calls), 1) + assert_called_included(self.proto.log.info, format="Unregister") d.callback(True) def check_hello_result(msg): @@ -1260,11 +1276,13 @@ def raise_exception(*args, **kwargs): channelID=chid)) def wait_for_times(): # pragma: nocover - if len(self.proto.log.failure.mock_calls) > 0: - eq_(len(self.proto.log.failure.mock_calls), 1) + if not self.proto.log.failure.called: + reactor.callLater(0.1, wait_for_times) + else: + self.proto.log.failure.assert_called_once() + assert_called_included(self.proto.log.info, + format="Unregister") d.callback(True) - return - reactor.callLater(0.1, wait_for_times) reactor.callLater(0.1, wait_for_times) return d @@ -1348,11 +1366,10 @@ def check_hello_result(msg): # Verify it was cleared out eq_(len(self.proto.ps.direct_updates), 0) eq_(len(self.proto.log.info.mock_calls), 1) - args, kwargs = self.proto.log.info.call_args - eq_(kwargs.get('format') or args[0], "Ack") - eq_(kwargs["router_key"], "simplepush") - eq_(kwargs["message_source"], "direct") - + assert_called_included(self.proto.log.info, + format="Ack", + router_key="simplepush", + message_source="direct") d.callback(True) f = self._check_response(check_hello_result) @@ -1380,10 +1397,10 @@ def test_ack_with_webpush_direct(self): )) eq_(self.proto.ps.direct_updates[chid], []) eq_(len(self.proto.log.info.mock_calls), 1) - args, kwargs = self.proto.log.info.call_args - eq_(kwargs.get('format') or args[0], "Ack") - eq_(kwargs["router_key"], "webpush") - eq_(kwargs["message_source"], "direct") + assert_called_included(self.proto.log.info, + format="Ack", + router_key="webpush", + message_source="direct") def test_ack_with_webpush_from_storage(self): self._connect() @@ -1405,10 +1422,10 @@ def test_ack_with_webpush_from_storage(self): assert self.proto.force_retry.called assert mock_defer.addBoth.called eq_(len(self.proto.log.info.mock_calls), 1) - args, kwargs = self.proto.log.info.call_args - eq_(kwargs.get('format') or args[0], "Ack") - eq_(kwargs["router_key"], "webpush") - eq_(kwargs["message_source"], "stored") + assert_called_included(self.proto.log.info, + format="Ack", + router_key="webpush", + message_source="stored") def test_nack(self): self._connect() diff --git a/autopush/websocket.py b/autopush/websocket.py index adfec15c..37e3833b 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -1010,6 +1010,7 @@ def error_register(self, fail): self.transport.resumeProducing() msg = {"messageType": "register", "status": 500} self.sendJSON(msg) + self.log_failure(fail, extra="Failed to register") def finish_register(self, endpoint, chid): """callback for successful endpoint creation, sends register reply""" @@ -1034,6 +1035,10 @@ def send_register_finish(self, result, endpoint, chid): self.sendJSON(msg) self.ps.metrics.increment("updates.client.register", tags=self.base_tags) + self.log.info(format="Register", channelID=chid, + endpoint=endpoint, + uaid_hash=self.ps.uaid_hash, + user_agent=self.ps.user_agent) def process_unregister(self, data): """Process an unregister message""" @@ -1048,12 +1053,12 @@ def process_unregister(self, data): self.ps.metrics.increment("updates.client.unregister", tags=self.base_tags) - # Log out the unregister if it has a code in it + event = dict(format="Unregister", channelID=chid, + uaid_hash=self.ps.uaid_hash, + user_agent=self.ps.user_agent) if "code" in data: - code = extract_code(data) - self.log.info(format="Unregister", channelID=chid, - uaid_hash=self.ps.uaid_hash, - user_agent=self.ps.user_agent, code=code) + event["code"] = extract_code(data) + self.log.info(**event) # Clear out any existing tracked messages for this channel if self.ps.use_webpush: