diff --git a/autopush/db.py b/autopush/db.py index 393c447e..c0ede2a4 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -52,6 +52,18 @@ def hasher(uaid): return uaid +def normalize_id(id): + if not id: + return id + if (len(id) == 36 and + id[8] == id[13] == id[18] == id[23] == '-'): + return id.lower() + raw = filter(lambda x: x in '0123456789abcdef', id.lower()) + if len(raw) != 32: + raise ValueError("Invalid UUID") + return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:])) + + def make_rotating_tablename(prefix, delta=0): """Creates a tablename for table rotation based on a prefix with a given month delta.""" @@ -261,7 +273,8 @@ def save_notification(self, uaid, chid, version): cond = "attribute_not_exists(version) or version < :ver" conn.put_item( self.table.table_name, - item=self.encode(dict(uaid=hasher(uaid), chid=chid, + item=self.encode(dict(uaid=hasher(uaid), + chid=normalize_id(chid), version=version)), condition_expression=cond, expression_attribute_values={ @@ -281,10 +294,12 @@ def delete_notification(self, uaid, chid, version=None): """ try: if version: - self.table.delete_item(uaid=hasher(uaid), chid=chid, + self.table.delete_item(uaid=hasher(uaid), + chid=normalize_id(chid), expected={"version__eq": version}) else: - self.table.delete_item(uaid=hasher(uaid), chid=chid) + self.table.delete_item(uaid=hasher(uaid), + chid=normalize_id(chid)) return True except ProvisionedThroughputExceededException: self.metrics.increment("error.provisioned.delete_notification") @@ -312,7 +327,8 @@ def register_channel(self, uaid, channel_id): db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) # Generate our update expression expr = "ADD chids :channel_id" - expr_values = self.encode({":channel_id": set([channel_id])}) + expr_values = self.encode({":channel_id": + set([normalize_id(channel_id)])}) conn.update_item( self.table.table_name, db_key, @@ -327,7 +343,8 @@ def unregister_channel(self, uaid, channel_id, **kwargs): conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) expr = "DELETE chids :channel_id" - expr_values = self.encode({":channel_id": set([channel_id])}) + expr_values = self.encode({":channel_id": + set([normalize_id(channel_id)])}) result = conn.update_item( self.table.table_name, @@ -375,7 +392,7 @@ def store_message(self, uaid, channel_id, message_id, ttl, data=None, with the message id""" item = dict( uaid=hasher(uaid), - chidmessageid="%s:%s" % (channel_id, message_id), + chidmessageid="%s:%s" % (normalize_id(channel_id), message_id), data=data, headers=headers, ttl=ttl, @@ -407,7 +424,7 @@ def update_message(self, uaid, channel_id, message_id, ttl, data=None, item["headers"] = headers item["data"] = data try: - chidmessageid = "%s:%s" % (channel_id, message_id) + chidmessageid = "%s:%s" % (normalize_id(channel_id), message_id) db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": chidmessageid}) expr = ("SET #tl=:ttl, #ts=:timestamp," @@ -439,14 +456,16 @@ def delete_message(self, uaid, channel_id, message_id, updateid=None): try: self.table.delete_item( uaid=hasher(uaid), - chidmessageid="%s:%s" % (channel_id, message_id), + chidmessageid="%s:%s" % (normalize_id(channel_id), + message_id), expected={'updateid__eq': updateid}) except ConditionalCheckFailedException: return False else: self.table.delete_item( uaid=hasher(uaid), - chidmessageid="%s:%s" % (channel_id, message_id)) + chidmessageid="%s:%s" % (normalize_id(channel_id), + message_id)) return True def delete_messages(self, uaid, chidmessageids): @@ -463,7 +482,7 @@ def delete_messages_for_channel(self, uaid, channel_id): """Deletes all messages for a uaid/channel_id""" results = self.table.query_2( uaid__eq=hasher(uaid), - chidmessageid__beginswith="%s:" % channel_id, + chidmessageid__beginswith="%s:" % normalize_id(channel_id), consistent=True, attributes=("chidmessageid",), ) diff --git a/autopush/endpoint.py b/autopush/endpoint.py index 25d0ad60..64873bad 100644 --- a/autopush/endpoint.py +++ b/autopush/endpoint.py @@ -42,7 +42,8 @@ from autopush.db import ( generate_last_connect, - hasher + hasher, + normalize_id, ) from autopush.router.interface import RouterException from autopush.utils import ( @@ -174,8 +175,8 @@ def chid(self): @chid.setter def chid(self, value): """Set the ChannelID and record to _client_info""" - self._chid = value - self._client_info["channelID"] = value + self._chid = normalize_id(value) + self._client_info["channelID"] = self._chid def _init_info(self): """Returns a dict of additional client data""" @@ -309,7 +310,7 @@ def _invalid_auth(self, fail): log.msg("Invalid bearer token: " + message, **self._client_info) raise VapidAuthException("Invalid bearer token: " + message) - def _process_auth(self, result): + def _process_auth(self, result, keys): """Process the optional VAPID auth token. VAPID requires two headers to be present; @@ -322,28 +323,26 @@ def _process_auth(self, result): # No auth present, so it's not a VAPID call. if not authorization: return result - - header_info = parse_header(self.request.headers.get('crypto-key')) - if not header_info: + if not keys: raise VapidAuthException("Missing Crypto-Key") - values = header_info[-1] - if isinstance(values, dict): - crypto_key = values.get('p256ecdsa') - try: - (auth_type, token) = authorization.split(' ', 1) - except ValueError: - raise VapidAuthException("Invalid Authorization header") - # if it's a bearer token containing what may be a JWT - if auth_type.lower() == AUTH_SCHEME and '.' in token: - d = deferToThread(extract_jwt, token, crypto_key) - d.addCallback(self._store_auth, crypto_key, token, result) - d.addErrback(self._invalid_auth) - return d - # otherwise, it's not, so ignore the VAPID data. - return result - else: + if not isinstance(keys, dict): + raise VapidAuthException("Invalid Crypto-Key") + crypto_key = keys.get('p256ecdsa') + if not crypto_key: raise VapidAuthException("Invalid bearer token: " "improperly specified crypto-key") + try: + (auth_type, token) = authorization.split(' ', 1) + except ValueError: + raise VapidAuthException("Invalid Authorization header") + # if it's a bearer token containing what may be a JWT + if auth_type.lower() == AUTH_SCHEME and '.' in token: + d = deferToThread(extract_jwt, token, crypto_key) + d.addCallback(self._store_auth, crypto_key, token, result) + d.addErrback(self._invalid_auth) + return d + # otherwise, it's not, so ignore the VAPID data. + return result class MessageHandler(AutoendpointHandler): @@ -404,17 +403,27 @@ class EndpointHandler(AutoendpointHandler): # Cyclone HTTP Methods ############################################################# @cyclone.web.asynchronous - def put(self, token): + def put(self, api_ver="v0", token=None): """HTTP PUT Handler Primary entry-point to handling a notification for a push client. """ self.start_time = time.time() - fernet = self.ap_settings.fernet - - d = deferToThread(fernet.decrypt, token.encode('utf8')) - d.addCallback(self._process_auth) + public_key = None + keys = {} + crypto_key = self.request.headers.get('crypto-key') + if crypto_key: + header_info = parse_header(crypto_key) + keys = header_info[-1] + if isinstance(keys, dict): + public_key = keys.get('p256ecdsa') + + d = deferToThread(self.ap_settings.parse_endpoint, + token, + api_ver, + public_key) + d.addCallback(self._process_auth, keys) d.addCallback(self._token_valid) d.addErrback(self._auth_err) d.addErrback(self._token_err) @@ -426,11 +435,10 @@ def put(self, token): ############################################################# def _token_valid(self, result): """Called after the token is decrypted successfully""" - info = result.split(":") - if len(info) != 2: + if len(result) != 2: raise ValueError("Wrong subscription token components") - self.uaid, self.chid = info + self.uaid, self.chid = result d = deferToThread(self.ap_settings.router.get_uaid, self.uaid) d.addCallback(self._uaid_lookup_results) d.addErrback(self._uaid_not_found_err) diff --git a/autopush/logging.py b/autopush/logging.py index 420dad7d..cc0dbf9d 100644 --- a/autopush/logging.py +++ b/autopush/logging.py @@ -15,7 +15,8 @@ import sys import raven -from eliot import add_destination, fields, Logger, MessageType +from eliot import (add_destination, fields, + Logger, MessageType) from twisted.python.log import textFromEventDict, startLoggingWithObserver HOSTNAME = socket.getfqdn() diff --git a/autopush/main.py b/autopush/main.py index 1f541099..d9579ecb 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -467,7 +467,8 @@ def endpoint_main(sysargs=None, use_files=True): # Endpoint HTTP router site = cyclone.web.Application([ - (r"/push/([^\/]+)", EndpointHandler, dict(ap_settings=settings)), + (r"/push/(?:(v\d+)\/)?([^\/]+)", EndpointHandler, + dict(ap_settings=settings)), (r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)), # PUT /register/ => connect info # GET /register/uaid => chid + endpoint diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 88c38197..a4f9ab4b 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -19,6 +19,7 @@ from autopush.protocol import IgnoreBody from autopush.router.interface import RouterException, RouterResponse from autopush.router.simple import SimpleRouter +from autopush.db import normalize_id TTL_URL = "https://webpush-wg.github.io/webpush-protocol/#rfc.section.6.2" @@ -72,7 +73,8 @@ def preflight_check(self, uaid, channel_id): self.ap_settings.message_tables[month_table].all_channels, uaid=uaid) - if not exists or channel_id not in chans: + if (not exists or channel_id.lower() not + in map(lambda x: normalize_id(x), chans)): raise RouterException("No such subscription", status_code=404, log_exception=False, errno=106) returnValue(month_table) @@ -84,7 +86,8 @@ def _send_notification(self, uaid, node_id, notification): headers for the notification. """ - payload = {"channelID": notification.channel_id, + # Firefox currently requires channelIDs to be '-' formatted. + payload = {"channelID": normalize_id(notification.channel_id), "version": notification.version, "ttl": notification.ttl or 0, "timestamp": int(time.time()), diff --git a/autopush/settings.py b/autopush/settings.py index 8a8daec4..41af7fa1 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -2,7 +2,10 @@ import datetime import socket +from hashlib import sha256 + from cryptography.fernet import Fernet, MultiFernet +from cryptography.hazmat.primitives import constant_time from twisted.internet import reactor from twisted.internet.defer import ( inlineCallbacks, @@ -232,6 +235,7 @@ def update_rotating_tables(self): self.current_msg_month = message_table.table_name self.message_tables[self.current_msg_month] = \ Message(message_table, self.metrics) + returnValue(True) def update(self, **kwargs): @@ -248,7 +252,57 @@ def update(self, **kwargs): else: setattr(self, key, val) - def make_endpoint(self, uaid, chid): - """ Create an endpoint from the identifiers""" - return self.endpoint_url + '/push/' + \ - self.fernet.encrypt((uaid + ':' + chid).encode('utf8')) + def make_endpoint(self, uaid, chid, key=None): + """Create an v1 or v2 endpoint from the indentifiers. + + Both endpoints use bytes instead of hex to reduce ID length. + v0 is uaid.hex + ':' + chid.hex and is deprecated. + v1 is the uaid + chid + v2 is the uaid + chid + sha256(key).bytes + + :param uaid: User Agent Identifier + :param chid: Channel or Subscription ID + :param key: Optional provided Public Key + :returns: Push endpoint + + """ + root = self.endpoint_url + '/push/' + base = (uaid.replace('-', '').decode("hex") + + chid.replace('-', '').decode("hex")) + + if key is None: + return root + 'v1/' + self.fernet.encrypt(base).strip('=') + + return root + 'v2/' + self.fernet.encrypt(base + sha256(key).digest()) + + def parse_endpoint(self, token, version="v0", public_key=None): + """Parse an endpoint into component elements of UAID, CHID and optional + key hash if v2 + + :param token: The obscured subscription data. + :param version: This is the API version of the token. + :param public_key: the public key (from Encryption-Key: p256ecdsa=) + + :raises ValueError: In the case of a malformed endpoint. + + :returns: a tuple containing the (UAID, CHID) + + """ + + token = self.fernet.decrypt(token.encode('utf8')) + + if version == 'v0': + if ':' not in token: + raise ValueError("Corrupted push token") + return tuple(token.split(':')) + if version == 'v1' and len(token) != 32: + raise ValueError("Corrupted push token") + if version == 'v2': + if len(token) != 64: + raise ValueError("Corrupted push token") + if not public_key: + raise ValueError("Invalid key data") + if not constant_time.bytes_eq(sha256(public_key).digest(), + token[32:]): + raise ValueError("Key mismatch") + return (token[:16].encode('hex'), token[16:32].encode('hex')) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index 622fdc96..1cb74104 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -28,6 +28,8 @@ mock_db2 = mock_dynamodb2() +dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) +dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) def setUp(): @@ -106,7 +108,7 @@ def raise_error(*args, **kwargs): raise ConditionalCheckFailedException(None, None) storage.table.connection.put_item.side_effect = raise_error - result = storage.save_notification("fdas", "asdf", 8) + result = storage.save_notification(dummy_uaid, dummy_chid, 8) eq_(result, False) def test_fetch_over_provisioned(self): @@ -119,7 +121,7 @@ def raise_error(*args, **kwargs): storage.table.connection.put_item.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - storage.save_notification("asdf", "asdf", 12) + storage.save_notification(dummy_uaid, dummy_chid, 12) def test_save_over_provisioned(self): s = get_storage_table() @@ -131,7 +133,7 @@ def raise_error(*args, **kwargs): storage.table.query_2.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - storage.fetch_notifications("asdf") + storage.fetch_notifications(dummy_uaid) def test_delete_over_provisioned(self): s = get_storage_table() @@ -142,7 +144,7 @@ def raise_error(*args, **kwargs): raise ProvisionedThroughputExceededException(None, None) storage.table.connection.delete_item.side_effect = raise_error - results = storage.delete_notification("asdf", "asdf") + results = storage.delete_notification(dummy_uaid, dummy_chid) eq_(results, False) @@ -195,7 +197,7 @@ def test_unregister(self): m.connection.update_item.return_value = { 'Attributes': {'uaid': {'S': self.uaid}}, 'ConsumedCapacityUnits': 0.5} - r = message.unregister_channel(self.uaid, "test") + r = message.unregister_channel(self.uaid, dummy_chid) eq_(r, False) def test_all_channels(self): @@ -232,7 +234,7 @@ def test_save_channels(self): def test_all_channels_no_uaid(self): m = get_rotating_message_table() message = Message(m, SinkMetrics()) - exists, chans = message.all_channels("asdf") + exists, chans = message.all_channels(dummy_uaid) assert(chans == set([])) def test_message_storage(self): @@ -322,7 +324,7 @@ def raise_condition(*args, **kwargs): message.table = Mock() message.table.delete_item.side_effect = raise_condition - result = message.delete_message(uaid="asdf", channel_id="asdf", + result = message.delete_message(uaid=dummy_uaid, channel_id=dummy_chid, message_id="asdf", updateid="asdf") eq_(result, False) @@ -415,7 +417,7 @@ def raise_error(*args, **kwargs): router.table.connection.update_item.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - router.register_user(dict(uaid="asdf", node_id="me", + router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234)) def test_clear_node_provision_failed(self): @@ -428,7 +430,8 @@ def raise_error(*args, **kwargs): router.table.connection.put_item.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - router.clear_node(Item(r, dict(uaid="asdf", connected_at="1234", + router.clear_node(Item(r, dict(uaid=dummy_uaid, + connected_at="1234", node_id="asdf"))) def test_save_uaid(self): @@ -465,7 +468,7 @@ def raise_condition(*args, **kwargs): router.table.connection = Mock() router.table.connection.update_item.side_effect = raise_condition - router_data = dict(uaid="asdf", node_id="asdf", connected_at=1234) + router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) result = router.register_user(router_data) eq_(result, (False, {}, router_data)) @@ -474,18 +477,18 @@ def test_node_clear(self): router = Router(r, SinkMetrics()) # Register a node user - router.register_user(dict(uaid="asdf", node_id="asdf", + router.register_user(dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234)) # Verify - user = router.get_uaid("asdf") + user = router.get_uaid(dummy_uaid) eq_(user["node_id"], "asdf") # Clear router.clear_node(user) # Verify - user = router.get_uaid("asdf") + user = router.get_uaid(dummy_uaid) eq_(user.get("node_id"), None) def test_node_clear_fail(self): @@ -497,7 +500,7 @@ def raise_condition(*args, **kwargs): router.table.connection.put_item = Mock() router.table.connection.put_item.side_effect = raise_condition - data = dict(uaid="asdf", node_id="asdf", connected_at=1234) + data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) result = router.clear_node(Item(r, data)) eq_(result, False) diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index 1dda1c04..20582e5e 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -6,6 +6,8 @@ import base64 import random +from hashlib import sha256 + import ecdsa import twisted.internet.base from cryptography.fernet import Fernet, InvalidToken @@ -37,6 +39,8 @@ from autopush.utils import (generate_hash, decipher_public_key) mock_dynamodb2 = mock_dynamodb2() +dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) +dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) def setUp(): @@ -228,13 +232,14 @@ def test_parse_header(self): result = utils.parse_header('fee; fie="foe"; fum=""foobar="five""";') eq_(result[0], 'fee') eq_(result[1], {'fum': 'foobar="five"', 'fie': 'foe'}) + eq_([], utils.parse_header("")) def test_uaid_lookup_results(self): fresult = dict(router_type="test") frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" self.endpoint.ap_settings.routers["test"] = frouter @@ -251,7 +256,7 @@ def test_uaid_lookup_results_bad_ttl(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["ttl"] = "woops" self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" @@ -270,7 +275,7 @@ def test_webpush_ttl_too_large(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["ttl"] = str(MAX_TTL + 100) self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" @@ -328,12 +333,12 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_webpush_bad_routertype(self): fresult = dict(router_type="fred") - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.body = b"stuff" self.endpoint._uaid_lookup_results(fresult) @@ -347,7 +352,7 @@ def test_webpush_uaid_lookup_no_crypto_headers_with_data(self): fresult = dict(router_type="webpush") frouter = self.settings.routers["webpush"] frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.body = b"stuff" self.endpoint._uaid_lookup_results(fresult) @@ -362,7 +367,7 @@ def test_webpush_uaid_lookup_no_crypto_headers_without_data(self): fresult = dict(router_type="webpush") frouter = self.settings.routers["webpush"] frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.endpoint._uaid_lookup_results(fresult) def handle_finish(value): @@ -377,7 +382,7 @@ def test_other_uaid_lookup_no_crypto_headers(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.endpoint.ap_settings.routers["test"] = frouter self.endpoint._uaid_lookup_results(fresult) @@ -391,7 +396,7 @@ def test_webpush_payload_encoding(self): fresult = dict(router_type="webpush") frouter = self.settings.routers["webpush"] frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" self.request_mock.body = b"\xc3\x28\xa0\xa1" @@ -401,7 +406,7 @@ def handle_finish(value): calls = frouter.route_notification.mock_calls eq_(len(calls), 1) (_, (notification, _), _) = calls[0] - eq_(notification.channel_id, "fred") + eq_(notification.channel_id, dummy_chid) eq_(notification.data, b"wyigoQ==") self.finish_deferred.addCallback(handle_finish) @@ -412,7 +417,7 @@ def test_other_payload_encoding(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.endpoint.ap_settings.routers["test"] = frouter self.request_mock.body = b"stuff" @@ -422,7 +427,7 @@ def handle_finish(value): calls = frouter.route_notification.mock_calls eq_(len(calls), 1) (_, (notification, _), _) = calls[0] - eq_(notification.channel_id, "fred") + eq_(notification.channel_id, dummy_chid) eq_(notification.data, b"stuff") self.finish_deferred.addCallback(handle_finish) @@ -444,10 +449,10 @@ def test_init_info(self): self.request_mock.headers["authorization"] = "bearer token fred" d = self.endpoint._init_info() eq_(d["authorization"], "bearer token fred") - self.endpoint.uaid = "faa" - eq_(self.endpoint._client_info["uaid_hash"], hasher("faa")) - self.endpoint.chid = "fie" - eq_(self.endpoint._client_info['channelID'], "fie") + self.endpoint.uaid = dummy_uaid + eq_(self.endpoint._client_info["uaid_hash"], hasher(dummy_uaid)) + self.endpoint.chid = dummy_chid + eq_(self.endpoint._client_info['channelID'], dummy_chid) def test_load_params_arguments(self): args = self.endpoint.request.arguments @@ -510,7 +515,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(413) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred @patch_logger @@ -525,7 +530,7 @@ def handle_finish(value): self._assert_error_response(value) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred def test_put_token_invalid(self): @@ -537,7 +542,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred def test_put_token_wrong(self): @@ -548,7 +553,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred def _throw_item_not_found(self, item): @@ -569,7 +574,12 @@ def handle_finish(result): self.endpoint.version, self.endpoint.data = 789, None - self.endpoint._token_valid('123:456') + exc = self.assertRaises(ValueError, + self.endpoint._token_valid, + ['invalid']) + eq_(exc.message, "Wrong subscription token components") + + self.endpoint._token_valid(['123', dummy_chid]) return self.finish_deferred def test_put_default_router(self): @@ -582,7 +592,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(200) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_router_with_headers(self): @@ -605,7 +615,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(200) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_router_needs_change(self): @@ -625,7 +635,7 @@ def handle_finish(result): assert(self.router_mock.register_user.called) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_router_needs_update(self): @@ -645,7 +655,7 @@ def handle_finish(result): assert(self.router_mock.register_user.called) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_bogus_headers(self): @@ -669,7 +679,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(400) self.finish_deferred.addBoth(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_invalid_vapid_crypto_header(self): @@ -693,10 +703,10 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred - def test_put_invalid_vapid_auth_header(self): + def test_put_invalid_vapid_crypto_key(self): self.request_mock.headers["encryption"] = "ignored" self.request_mock.headers["content-encoding"] = 'text' self.request_mock.headers["authorization"] = "invalid" @@ -717,7 +727,31 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) + return self.finish_deferred + + def test_put_invalid_vapid_auth_header(self): + self.request_mock.headers["encryption"] = "ignored" + self.request_mock.headers["content-encoding"] = 'text' + self.request_mock.headers["authorization"] = "invalid" + self.request_mock.headers["crypto-key"] = "p256ecdsa=crap" + self.request_mock.body = b' ' + self.fernet_mock.decrypt.return_value = "123:456" + self.router_mock.get_uaid.return_value = dict( + router_type="webpush", + router_data=dict(), + ) + self.wp_router_mock.route_notification.return_value = RouterResponse( + status_code=200, + router_data={}, + ) + + def handle_finish(result): + self.assertTrue(result) + self.endpoint.set_status.assert_called_with(401) + + self.finish_deferred.addCallback(handle_finish) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_missing_vapid_crypto_header(self): @@ -740,7 +774,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_headers_in_response(self): @@ -764,7 +798,7 @@ def handle_finish(result): "Location", "Somewhere") self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def _gen_jwt(self, header, payload): @@ -812,7 +846,7 @@ def handle_finish(result, crypto_key, token): self.assertTrue(result) self.finish_deferred.addCallback(handle_finish, crypto_key, token) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_decipher_public_key(self): @@ -865,7 +899,7 @@ def handle_finish(result): self.assertTrue(result) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_bad_vapid_auth(self): @@ -897,7 +931,7 @@ def handle_finish(result): self.assertTrue(result) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_no_sig(self): @@ -932,7 +966,7 @@ def handle_finish(result): self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_util_extract_jwt(self): @@ -970,7 +1004,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_bad_exp(self): @@ -1002,7 +1036,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_auth(self): @@ -1027,7 +1061,7 @@ def handle_finish(result): "Location", "Somewhere") self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_logged_delivered(self): @@ -1058,7 +1092,7 @@ def handle_finish(result): log_patcher.stop() self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_logged_stored(self): @@ -1089,7 +1123,7 @@ def handle_finish(result): log_patcher.stop() self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred @patch("twisted.python.log") @@ -1120,7 +1154,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(503) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_put_db_error(self): @@ -1132,7 +1166,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(503) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_cors(self): @@ -1221,9 +1255,84 @@ def test_padding(self): eq_(utils.fix_padding("ab"), "ab==") eq_(utils.fix_padding("abcd"), "abcd") + def test_parse_endpoint(self): + v0_valid = dummy_uaid + ":" + dummy_chid + uaid_strip = dummy_uaid.replace('-', '') + chid_strip = dummy_chid.replace('-', '') + uaid_dec = uaid_strip.decode('hex') + chid_dec = chid_strip.decode('hex') + v1_valid = uaid_dec + chid_dec + pub_key = uuid.uuid4().hex + v2_valid = sha256(pub_key).digest() + v2_invalid = sha256(uuid.uuid4().hex).digest() + # v0 good + self.fernet_mock.decrypt.return_value = v0_valid + tokens = self.settings.parse_endpoint('/valid') + eq_(tokens, (dummy_uaid, dummy_chid)) + + # v0 bad + self.fernet_mock.decrypt.return_value = v1_valid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + '/invalid') + eq_(exc.message, 'Corrupted push token') + + self.fernet_mock.decrypt.return_value = v1_valid[:30] + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v1') + eq_(exc.message, 'Corrupted push token') + + self.fernet_mock.decrypt.return_value = v1_valid + tokens = self.settings.parse_endpoint('valid', 'v1') + eq_(tokens, (uaid_strip, chid_strip)) + + self.fernet_mock.decrypt.return_value = v1_valid + v2_valid + tokens = self.settings.parse_endpoint('valid', 'v2', pub_key) + eq_(tokens, (uaid_strip, chid_strip)) + + self.fernet_mock.decrypt.return_value = v1_valid + "invalid" + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2', pub_key) + eq_(exc.message, "Corrupted push token") + + self.fernet_mock.decrypt.return_value = v1_valid + v2_valid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2', pub_key[:30]) + eq_(exc.message, "Key mismatch") + + self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2') + eq_(exc.message, "Invalid key data") + + self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2', pub_key) + eq_(exc.message, "Key mismatch") + + def test_make_endpoint(self): + + def echo(val): + return val.encode('hex') + + # make a v1 endpoint: + self.fernet_mock.encrypt = echo + strip_uaid = dummy_uaid.replace('-', '') + strip_chid = dummy_chid.replace('-', '') + dummy_key = "RandomKeyString" + sha = sha256(dummy_key).hexdigest() + ep = self.settings.make_endpoint(dummy_uaid, dummy_chid) + eq_(ep, 'http://localhost/push/v1/' + strip_uaid + strip_chid) + ep = self.settings.make_endpoint(dummy_uaid, dummy_chid, + "RandomKeyString") + eq_(ep, 'http://localhost/push/v2/' + strip_uaid + strip_chid + sha) + -dummy_uaid = "00000000123412341234567812345678" -dummy_chid = "11111111123412341234567812345678" CORS_HEAD = "POST,PUT,DELETE" @@ -1378,9 +1487,9 @@ def handle_finish(value): ok_(call_args is not None) args = call_args[0] call_arg = json.loads(args[0]) - eq_(call_arg["uaid"], dummy_uaid) + eq_(call_arg["uaid"], dummy_uaid.replace('-', '')) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") ok_("secret" in call_arg) self.finish_deferred.addCallback(handle_finish) @@ -1417,9 +1526,9 @@ def handle_finish(value): ok_(call_args is not None) args = call_args[0] call_arg = json.loads(args[0]) - eq_(call_arg["uaid"], dummy_uaid) + eq_(call_arg["uaid"], dummy_uaid.replace('-', '')) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") calls = self.reg.ap_settings.router.register_user.call_args call_args = calls[0][0] eq_(True, has_connected_this_month(call_args)) @@ -1475,7 +1584,7 @@ def handle_finish(value): args = call_args[0] call_arg = json.loads(args[0]) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth @@ -1530,7 +1639,7 @@ def handle_finish(value): args = call_args[0] call_arg = json.loads(args[0]) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth @@ -1555,7 +1664,7 @@ def handle_finish(value): args = call_args[0] call_arg = json.loads(args[0]) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth @@ -1613,23 +1722,24 @@ def test_delete_chid(self): dummy_chid, "1", 10000) - messages.register_channel(dummy_uaid, "test") + chid2 = str(uuid.uuid4()) + messages.register_channel(dummy_uaid, chid2) messages.store_message( dummy_uaid, - "test", + chid2, "2", 10000) self.reg.request.headers["Authorization"] = self.auth - def handle_finish(value): + def handle_finish(value, chid2): ml = messages.fetch_messages(dummy_uaid) cl = messages.all_channels(dummy_uaid) eq_(len(ml), 1) - eq_((True, set(['test'])), cl) + eq_((True, set([chid2])), cl) messages.delete_user(dummy_uaid) - self.finish_deferred.addCallback(handle_finish) - self.reg.delete("simplepush", "test", uaid=dummy_uaid, chid=dummy_chid) + self.finish_deferred.addCallback(handle_finish, chid2) + self.reg.delete("simplepush", "test", dummy_uaid, dummy_chid) return self.finish_deferred def test_delete_bad_chid(self): @@ -1647,11 +1757,12 @@ def handle_finish(value): messages.delete_user(dummy_uaid) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid=dummy_uaid, chid="invalid") + self.reg.delete("test", "test", dummy_uaid, "invalid") return self.finish_deferred def test_delete_uaid(self): messages = self.reg.ap_settings.message + chid2 = str(uuid.uuid4()) messages.store_message( dummy_uaid, dummy_chid, @@ -1659,15 +1770,14 @@ def test_delete_uaid(self): 10000) messages.store_message( dummy_uaid, - "test", + chid2, "2", 10000) self.reg.ap_settings.router.drop_user = Mock() self.reg.ap_settings.router.drop_user.return_value = True - def handle_finish(value): + def handle_finish(value, chid2): ml = messages.fetch_messages(dummy_uaid) - cl = messages.all_channels(dummy_uaid) eq_(len(ml), 0) # Note: Router is mocked, so the UAID is never actually # dropped. Normally, this should messages.all_channels @@ -1675,11 +1785,10 @@ def handle_finish(value): ok_(self.reg.ap_settings.router.drop_user.called) eq_(self.reg.ap_settings.router.drop_user.call_args_list[0][0], (dummy_uaid,)) - eq_((True, set(["test"])), cl) - self.finish_deferred.addCallback(handle_finish) + self.finish_deferred.addCallback(handle_finish, chid2) self.reg.request.headers["Authorization"] = self.auth - self.reg.delete("simplepush", "test", uaid=dummy_uaid) + self.reg.delete("simplepush", "test", dummy_uaid) return self.finish_deferred def test_delete_bad_uaid(self): @@ -1689,7 +1798,7 @@ def handle_finish(value): self.reg.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid="invalid") + self.reg.delete("test", "test", "invalid") return self.finish_deferred def test_delete_orphans(self): @@ -1701,7 +1810,7 @@ def handle_finish(value): self.router_mock.drop_user = Mock() self.router_mock.drop_user.return_value = False self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid=dummy_uaid) + self.reg.delete("test", "test", dummy_uaid) return self.finish_deferred def test_delete_bad_auth(self, *args): @@ -1711,7 +1820,7 @@ def handle_finish(value): self.reg.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid=dummy_uaid) + self.reg.delete("test", "test", dummy_uaid) return self.finish_deferred def test_delete_bad_router(self): @@ -1721,7 +1830,7 @@ def handle_finish(value): self.reg.set_status.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("invalid", "test", uaid=dummy_uaid) + self.reg.delete("invalid", "test", dummy_uaid) return self.finish_deferred def test_validate_auth(self): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index e8dbf4e8..b6ebbc5f 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -90,7 +90,7 @@ def _get_vapid(key=None, payload=None): key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) vk = key.get_verifying_key() auth = jws.sign(payload, key, algorithm="ES256").strip('=') - crypto_key = base64.urlsafe_b64encode(vk.to_string()).strip('=') + crypto_key = base64.urlsafe_b64encode('\4' + vk.to_string()).strip('=') return {"auth": auth, "crypto-key": crypto_key, "key": key} @@ -146,9 +146,11 @@ def hello(self, uaid=None): eq_(result["status"], 200) return result - def register(self, chid=None): + def register(self, chid=None, key=None): chid = chid or str(uuid.uuid4()) - msg = json.dumps(dict(messageType="register", channelID=chid)) + msg = json.dumps(dict(messageType="register", + channelID=chid, + key=key)) log.debug("Send: %s", msg) self.ws.send(msg) result = json.loads(self.ws.recv()) @@ -186,20 +188,17 @@ def delete_notification(self, channel, message=None, status=204): def send_notification(self, channel=None, version=None, data=None, use_header=True, status=None, ttl=200, - timeout=0.2, vapid=False): + timeout=0.2, vapid=None): if not channel: channel = random.choice(self.channels.keys()) endpoint = self.channels[channel] url = urlparse.urlparse(endpoint) http = None - vapid_info = None if url.scheme == "https": # pragma: nocover http = httplib.HTTPSConnection(url.netloc) else: http = httplib.HTTPConnection(url.netloc) - if vapid: - vapid_info = _get_vapid() if self.use_webpush: headers = {} @@ -214,9 +213,9 @@ def send_notification(self, channel=None, version=None, data=None, }) if vapid: headers.update({ - "Authorization": "Bearer " + vapid_info.get('auth') + "Authorization": "Bearer " + vapid.get('auth') }) - ckey = 'p256ecdsa="' + vapid_info.get('crypto-key') + '"' + ckey = 'p256ecdsa="' + vapid.get('crypto-key') + '"' headers.update({ 'Crypto-Key': headers.get('Crypto-Key') + ';' + ckey }) @@ -362,7 +361,8 @@ def setUp(self): # Endpoint HTTP router site = cyclone.web.Application([ - (r"/push/([^\/]+)", EndpointHandler, dict(ap_settings=settings)), + (r"/push/(v\d+)?/?([^\/]+)", EndpointHandler, + dict(ap_settings=settings)), (r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)), # PUT /register/ => connect info # GET /register/uaid => chid + endpoint @@ -693,7 +693,8 @@ def test_basic_delivery(self): def test_basic_delivery_with_vapid(self): data = str(uuid.uuid4()) client = yield self.quick_register(use_webpush=True) - result = yield client.send_notification(data=data, vapid=True) + vapid_info = _get_vapid() + result = yield client.send_notification(data=data, vapid=vapid_info) eq_(result["headers"]["encryption"], client._crypto_key) eq_(result["data"], urlsafe_b64encode(data)) eq_(result["messageType"], "notification") @@ -1250,6 +1251,34 @@ def test_webpush_monthly_rotation_no_channels(self): yield self.shut_down(client) + @inlineCallbacks + def test_with_key(self): + private_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) + claims = {"aud": "http://example.com", + "exp": int(time.time()) + 86400, + "sub": "a@example.com"} + vapid = _get_vapid(private_key, claims) + pk_hex = vapid['crypto-key'] + chid = str(uuid.uuid4()) + client = Client("ws://localhost:9010/", use_webpush=True) + yield client.connect() + yield client.hello() + yield client.register(chid=chid, key=pk_hex) + # check that the client actually registered the key. + + # Send an update with a properly formatted key. + yield client.send_notification(vapid=vapid) + + # now try an invalid key. + new_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) + vapid = _get_vapid(new_key, claims) + + yield client.send_notification( + vapid=vapid, + status=400) + + yield self.shut_down(client) + class TestHealth(IntegrationBase): @inlineCallbacks diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index e8f72e07..6fbac597 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -1,6 +1,7 @@ import json import time import uuid +from hashlib import sha256 import twisted.internet.base from boto.dynamodb2.exceptions import ( @@ -475,7 +476,7 @@ def test_hello_failure(self): router = self.proto.ap_settings.router router.table.connection.update_item = Mock(side_effect=KeyError) - self._send_message(dict(messageType="hello", channelIDs=[])) + self._send_message(dict(messageType="hello", channelIDs=[], stop=1)) def check_result(msg): eq_(msg["status"], 503) @@ -761,7 +762,7 @@ def check_result(result): def test_register(self): self._connect() - self._send_message(dict(messageType="hello", channelIDs=[])) + self._send_message(dict(messageType="hello", channelIDs=[], stop=1)) d = Deferred() d.addCallback(lambda x: True) @@ -799,6 +800,39 @@ def check_register_result(msg): res.addCallback(check_register_result) return d + def test_register_webpush_with_key(self): + self._connect() + self.proto.ps.use_webpush = True + chid = str(uuid.uuid4()) + self.proto.ps.uaid = str(uuid.uuid4()) + self.proto.ap_settings.message.register_channel = Mock() + test_key = "SomeRandomCryptoKeyString" + test_sha = sha256(test_key).hexdigest() + test_endpoint = ('http://localhost/push/v2/' + + self.proto.ps.uaid.replace('-', '') + + chid.replace('-', '') + + test_sha) + self.proto.sendJSON = Mock() + + def echo(str): + return str.encode('hex') + + self.proto.ap_settings.fernet.encrypt = echo + + d = Deferred() + + def check_register_result(msg, test_endpoint): + eq_(test_endpoint, + self.proto.sendJSON.call_args[0][0]['pushEndpoint']) + assert self.proto.ap_settings.message.register_channel.called + d.callback(True) + + res = self.proto.process_register( + dict(channelID=chid, + key=test_key)) + res.addCallback(check_register_result, test_endpoint) + return d + def test_register_no_chid(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) @@ -1098,37 +1132,6 @@ def test_notification_avoid_newer_delivery(self): args = self.send_mock.call_args eq_(args, None) - def test_notification_retains_no_dash(self): - self._connect() - - uaid = str(uuid.uuid4()).replace('-', '') - chid = str(uuid.uuid4()).replace('-', '') - - storage = self.proto.ap_settings.storage - storage.save_notification(uaid, chid, 10) - self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) - - d = Deferred() - - def check_notif_result(msg): - eq_(msg["messageType"], "notification") - updates = msg["updates"] - eq_(len(updates), 1) - eq_(updates[0]["channelID"], chid) - eq_(updates[0]["version"], 10) - d.callback(True) - - def check_hello_result(msg): - eq_(msg["status"], 200) - - # Now wait for the notification - nd = self._check_response(check_notif_result) - nd.addErrback(lambda x: d.errback(x)) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d - def test_ack(self): patcher = patch('autopush.websocket.log', spec=True) mock_log = patcher.start() diff --git a/autopush/websocket.py b/autopush/websocket.py index cccfb6b4..56768fb8 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -1000,7 +1000,7 @@ def process_register(self, data): self.transport.pauseProducing() d = self.deferToThread(self.ap_settings.make_endpoint, self.ps.uaid, - chid) + chid, data.get("key")) d.addCallback(self.finish_register, chid) d.addErrback(self.trap_cancel) d.addErrback(self.error_register)