From 3d15b9bbc5002d8c6b03b3fd57418aa1892be0e7 Mon Sep 17 00:00:00 2001 From: jrconlin Date: Mon, 29 Feb 2016 15:01:25 -0800 Subject: [PATCH] feature: allow channels to register with public key Channels can include a public key when registering a new subscription channel. This public key should match the public key used to send subscription updates later. NOTE: this patch changes the format of the endpoint URLs, & the content of the endpoint URL token. This change also requires that ChannelIDs be normalized to dashed format, (e.g. a lower case, dash delimited string "deadbeef-0000-0000-deca-fbad11112222") This is the default mechanism used by Firefox for UUID generation. It is STRONGLY urged that clients normalize UUIDs used for ChannelIDs and User Agent IDs. While this should not break existing clients, additional testing may be required. Closes #326 --- autopush/db.py | 39 +++-- autopush/endpoint.py | 70 +++++---- autopush/logging.py | 3 +- autopush/main.py | 3 +- autopush/router/webpush.py | 7 +- autopush/settings.py | 62 +++++++- autopush/tests/test_db.py | 31 ++-- autopush/tests/test_endpoint.py | 245 +++++++++++++++++++++-------- autopush/tests/test_integration.py | 51 ++++-- autopush/tests/test_websocket.py | 69 ++++---- autopush/websocket.py | 2 +- 11 files changed, 406 insertions(+), 176 deletions(-) diff --git a/autopush/db.py b/autopush/db.py index 393c447e..c0ede2a4 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -52,6 +52,18 @@ def hasher(uaid): return uaid +def normalize_id(id): + if not id: + return id + if (len(id) == 36 and + id[8] == id[13] == id[18] == id[23] == '-'): + return id.lower() + raw = filter(lambda x: x in '0123456789abcdef', id.lower()) + if len(raw) != 32: + raise ValueError("Invalid UUID") + return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:])) + + def make_rotating_tablename(prefix, delta=0): """Creates a tablename for table rotation based on a prefix with a given month delta.""" @@ -261,7 +273,8 @@ def save_notification(self, uaid, chid, version): cond = "attribute_not_exists(version) or version < :ver" conn.put_item( self.table.table_name, - item=self.encode(dict(uaid=hasher(uaid), chid=chid, + item=self.encode(dict(uaid=hasher(uaid), + chid=normalize_id(chid), version=version)), condition_expression=cond, expression_attribute_values={ @@ -281,10 +294,12 @@ def delete_notification(self, uaid, chid, version=None): """ try: if version: - self.table.delete_item(uaid=hasher(uaid), chid=chid, + self.table.delete_item(uaid=hasher(uaid), + chid=normalize_id(chid), expected={"version__eq": version}) else: - self.table.delete_item(uaid=hasher(uaid), chid=chid) + self.table.delete_item(uaid=hasher(uaid), + chid=normalize_id(chid)) return True except ProvisionedThroughputExceededException: self.metrics.increment("error.provisioned.delete_notification") @@ -312,7 +327,8 @@ def register_channel(self, uaid, channel_id): db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) # Generate our update expression expr = "ADD chids :channel_id" - expr_values = self.encode({":channel_id": set([channel_id])}) + expr_values = self.encode({":channel_id": + set([normalize_id(channel_id)])}) conn.update_item( self.table.table_name, db_key, @@ -327,7 +343,8 @@ def unregister_channel(self, uaid, channel_id, **kwargs): conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) expr = "DELETE chids :channel_id" - expr_values = self.encode({":channel_id": set([channel_id])}) + expr_values = self.encode({":channel_id": + set([normalize_id(channel_id)])}) result = conn.update_item( self.table.table_name, @@ -375,7 +392,7 @@ def store_message(self, uaid, channel_id, message_id, ttl, data=None, with the message id""" item = dict( uaid=hasher(uaid), - chidmessageid="%s:%s" % (channel_id, message_id), + chidmessageid="%s:%s" % (normalize_id(channel_id), message_id), data=data, headers=headers, ttl=ttl, @@ -407,7 +424,7 @@ def update_message(self, uaid, channel_id, message_id, ttl, data=None, item["headers"] = headers item["data"] = data try: - chidmessageid = "%s:%s" % (channel_id, message_id) + chidmessageid = "%s:%s" % (normalize_id(channel_id), message_id) db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": chidmessageid}) expr = ("SET #tl=:ttl, #ts=:timestamp," @@ -439,14 +456,16 @@ def delete_message(self, uaid, channel_id, message_id, updateid=None): try: self.table.delete_item( uaid=hasher(uaid), - chidmessageid="%s:%s" % (channel_id, message_id), + chidmessageid="%s:%s" % (normalize_id(channel_id), + message_id), expected={'updateid__eq': updateid}) except ConditionalCheckFailedException: return False else: self.table.delete_item( uaid=hasher(uaid), - chidmessageid="%s:%s" % (channel_id, message_id)) + chidmessageid="%s:%s" % (normalize_id(channel_id), + message_id)) return True def delete_messages(self, uaid, chidmessageids): @@ -463,7 +482,7 @@ def delete_messages_for_channel(self, uaid, channel_id): """Deletes all messages for a uaid/channel_id""" results = self.table.query_2( uaid__eq=hasher(uaid), - chidmessageid__beginswith="%s:" % channel_id, + chidmessageid__beginswith="%s:" % normalize_id(channel_id), consistent=True, attributes=("chidmessageid",), ) diff --git a/autopush/endpoint.py b/autopush/endpoint.py index 25d0ad60..64873bad 100644 --- a/autopush/endpoint.py +++ b/autopush/endpoint.py @@ -42,7 +42,8 @@ from autopush.db import ( generate_last_connect, - hasher + hasher, + normalize_id, ) from autopush.router.interface import RouterException from autopush.utils import ( @@ -174,8 +175,8 @@ def chid(self): @chid.setter def chid(self, value): """Set the ChannelID and record to _client_info""" - self._chid = value - self._client_info["channelID"] = value + self._chid = normalize_id(value) + self._client_info["channelID"] = self._chid def _init_info(self): """Returns a dict of additional client data""" @@ -309,7 +310,7 @@ def _invalid_auth(self, fail): log.msg("Invalid bearer token: " + message, **self._client_info) raise VapidAuthException("Invalid bearer token: " + message) - def _process_auth(self, result): + def _process_auth(self, result, keys): """Process the optional VAPID auth token. VAPID requires two headers to be present; @@ -322,28 +323,26 @@ def _process_auth(self, result): # No auth present, so it's not a VAPID call. if not authorization: return result - - header_info = parse_header(self.request.headers.get('crypto-key')) - if not header_info: + if not keys: raise VapidAuthException("Missing Crypto-Key") - values = header_info[-1] - if isinstance(values, dict): - crypto_key = values.get('p256ecdsa') - try: - (auth_type, token) = authorization.split(' ', 1) - except ValueError: - raise VapidAuthException("Invalid Authorization header") - # if it's a bearer token containing what may be a JWT - if auth_type.lower() == AUTH_SCHEME and '.' in token: - d = deferToThread(extract_jwt, token, crypto_key) - d.addCallback(self._store_auth, crypto_key, token, result) - d.addErrback(self._invalid_auth) - return d - # otherwise, it's not, so ignore the VAPID data. - return result - else: + if not isinstance(keys, dict): + raise VapidAuthException("Invalid Crypto-Key") + crypto_key = keys.get('p256ecdsa') + if not crypto_key: raise VapidAuthException("Invalid bearer token: " "improperly specified crypto-key") + try: + (auth_type, token) = authorization.split(' ', 1) + except ValueError: + raise VapidAuthException("Invalid Authorization header") + # if it's a bearer token containing what may be a JWT + if auth_type.lower() == AUTH_SCHEME and '.' in token: + d = deferToThread(extract_jwt, token, crypto_key) + d.addCallback(self._store_auth, crypto_key, token, result) + d.addErrback(self._invalid_auth) + return d + # otherwise, it's not, so ignore the VAPID data. + return result class MessageHandler(AutoendpointHandler): @@ -404,17 +403,27 @@ class EndpointHandler(AutoendpointHandler): # Cyclone HTTP Methods ############################################################# @cyclone.web.asynchronous - def put(self, token): + def put(self, api_ver="v0", token=None): """HTTP PUT Handler Primary entry-point to handling a notification for a push client. """ self.start_time = time.time() - fernet = self.ap_settings.fernet - - d = deferToThread(fernet.decrypt, token.encode('utf8')) - d.addCallback(self._process_auth) + public_key = None + keys = {} + crypto_key = self.request.headers.get('crypto-key') + if crypto_key: + header_info = parse_header(crypto_key) + keys = header_info[-1] + if isinstance(keys, dict): + public_key = keys.get('p256ecdsa') + + d = deferToThread(self.ap_settings.parse_endpoint, + token, + api_ver, + public_key) + d.addCallback(self._process_auth, keys) d.addCallback(self._token_valid) d.addErrback(self._auth_err) d.addErrback(self._token_err) @@ -426,11 +435,10 @@ def put(self, token): ############################################################# def _token_valid(self, result): """Called after the token is decrypted successfully""" - info = result.split(":") - if len(info) != 2: + if len(result) != 2: raise ValueError("Wrong subscription token components") - self.uaid, self.chid = info + self.uaid, self.chid = result d = deferToThread(self.ap_settings.router.get_uaid, self.uaid) d.addCallback(self._uaid_lookup_results) d.addErrback(self._uaid_not_found_err) diff --git a/autopush/logging.py b/autopush/logging.py index 420dad7d..cc0dbf9d 100644 --- a/autopush/logging.py +++ b/autopush/logging.py @@ -15,7 +15,8 @@ import sys import raven -from eliot import add_destination, fields, Logger, MessageType +from eliot import (add_destination, fields, + Logger, MessageType) from twisted.python.log import textFromEventDict, startLoggingWithObserver HOSTNAME = socket.getfqdn() diff --git a/autopush/main.py b/autopush/main.py index 1f541099..d9579ecb 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -467,7 +467,8 @@ def endpoint_main(sysargs=None, use_files=True): # Endpoint HTTP router site = cyclone.web.Application([ - (r"/push/([^\/]+)", EndpointHandler, dict(ap_settings=settings)), + (r"/push/(?:(v\d+)\/)?([^\/]+)", EndpointHandler, + dict(ap_settings=settings)), (r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)), # PUT /register/ => connect info # GET /register/uaid => chid + endpoint diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 88c38197..a4f9ab4b 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -19,6 +19,7 @@ from autopush.protocol import IgnoreBody from autopush.router.interface import RouterException, RouterResponse from autopush.router.simple import SimpleRouter +from autopush.db import normalize_id TTL_URL = "https://webpush-wg.github.io/webpush-protocol/#rfc.section.6.2" @@ -72,7 +73,8 @@ def preflight_check(self, uaid, channel_id): self.ap_settings.message_tables[month_table].all_channels, uaid=uaid) - if not exists or channel_id not in chans: + if (not exists or channel_id.lower() not + in map(lambda x: normalize_id(x), chans)): raise RouterException("No such subscription", status_code=404, log_exception=False, errno=106) returnValue(month_table) @@ -84,7 +86,8 @@ def _send_notification(self, uaid, node_id, notification): headers for the notification. """ - payload = {"channelID": notification.channel_id, + # Firefox currently requires channelIDs to be '-' formatted. + payload = {"channelID": normalize_id(notification.channel_id), "version": notification.version, "ttl": notification.ttl or 0, "timestamp": int(time.time()), diff --git a/autopush/settings.py b/autopush/settings.py index 8a8daec4..41af7fa1 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -2,7 +2,10 @@ import datetime import socket +from hashlib import sha256 + from cryptography.fernet import Fernet, MultiFernet +from cryptography.hazmat.primitives import constant_time from twisted.internet import reactor from twisted.internet.defer import ( inlineCallbacks, @@ -232,6 +235,7 @@ def update_rotating_tables(self): self.current_msg_month = message_table.table_name self.message_tables[self.current_msg_month] = \ Message(message_table, self.metrics) + returnValue(True) def update(self, **kwargs): @@ -248,7 +252,57 @@ def update(self, **kwargs): else: setattr(self, key, val) - def make_endpoint(self, uaid, chid): - """ Create an endpoint from the identifiers""" - return self.endpoint_url + '/push/' + \ - self.fernet.encrypt((uaid + ':' + chid).encode('utf8')) + def make_endpoint(self, uaid, chid, key=None): + """Create an v1 or v2 endpoint from the indentifiers. + + Both endpoints use bytes instead of hex to reduce ID length. + v0 is uaid.hex + ':' + chid.hex and is deprecated. + v1 is the uaid + chid + v2 is the uaid + chid + sha256(key).bytes + + :param uaid: User Agent Identifier + :param chid: Channel or Subscription ID + :param key: Optional provided Public Key + :returns: Push endpoint + + """ + root = self.endpoint_url + '/push/' + base = (uaid.replace('-', '').decode("hex") + + chid.replace('-', '').decode("hex")) + + if key is None: + return root + 'v1/' + self.fernet.encrypt(base).strip('=') + + return root + 'v2/' + self.fernet.encrypt(base + sha256(key).digest()) + + def parse_endpoint(self, token, version="v0", public_key=None): + """Parse an endpoint into component elements of UAID, CHID and optional + key hash if v2 + + :param token: The obscured subscription data. + :param version: This is the API version of the token. + :param public_key: the public key (from Encryption-Key: p256ecdsa=) + + :raises ValueError: In the case of a malformed endpoint. + + :returns: a tuple containing the (UAID, CHID) + + """ + + token = self.fernet.decrypt(token.encode('utf8')) + + if version == 'v0': + if ':' not in token: + raise ValueError("Corrupted push token") + return tuple(token.split(':')) + if version == 'v1' and len(token) != 32: + raise ValueError("Corrupted push token") + if version == 'v2': + if len(token) != 64: + raise ValueError("Corrupted push token") + if not public_key: + raise ValueError("Invalid key data") + if not constant_time.bytes_eq(sha256(public_key).digest(), + token[32:]): + raise ValueError("Key mismatch") + return (token[:16].encode('hex'), token[16:32].encode('hex')) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index 622fdc96..1cb74104 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -28,6 +28,8 @@ mock_db2 = mock_dynamodb2() +dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) +dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) def setUp(): @@ -106,7 +108,7 @@ def raise_error(*args, **kwargs): raise ConditionalCheckFailedException(None, None) storage.table.connection.put_item.side_effect = raise_error - result = storage.save_notification("fdas", "asdf", 8) + result = storage.save_notification(dummy_uaid, dummy_chid, 8) eq_(result, False) def test_fetch_over_provisioned(self): @@ -119,7 +121,7 @@ def raise_error(*args, **kwargs): storage.table.connection.put_item.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - storage.save_notification("asdf", "asdf", 12) + storage.save_notification(dummy_uaid, dummy_chid, 12) def test_save_over_provisioned(self): s = get_storage_table() @@ -131,7 +133,7 @@ def raise_error(*args, **kwargs): storage.table.query_2.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - storage.fetch_notifications("asdf") + storage.fetch_notifications(dummy_uaid) def test_delete_over_provisioned(self): s = get_storage_table() @@ -142,7 +144,7 @@ def raise_error(*args, **kwargs): raise ProvisionedThroughputExceededException(None, None) storage.table.connection.delete_item.side_effect = raise_error - results = storage.delete_notification("asdf", "asdf") + results = storage.delete_notification(dummy_uaid, dummy_chid) eq_(results, False) @@ -195,7 +197,7 @@ def test_unregister(self): m.connection.update_item.return_value = { 'Attributes': {'uaid': {'S': self.uaid}}, 'ConsumedCapacityUnits': 0.5} - r = message.unregister_channel(self.uaid, "test") + r = message.unregister_channel(self.uaid, dummy_chid) eq_(r, False) def test_all_channels(self): @@ -232,7 +234,7 @@ def test_save_channels(self): def test_all_channels_no_uaid(self): m = get_rotating_message_table() message = Message(m, SinkMetrics()) - exists, chans = message.all_channels("asdf") + exists, chans = message.all_channels(dummy_uaid) assert(chans == set([])) def test_message_storage(self): @@ -322,7 +324,7 @@ def raise_condition(*args, **kwargs): message.table = Mock() message.table.delete_item.side_effect = raise_condition - result = message.delete_message(uaid="asdf", channel_id="asdf", + result = message.delete_message(uaid=dummy_uaid, channel_id=dummy_chid, message_id="asdf", updateid="asdf") eq_(result, False) @@ -415,7 +417,7 @@ def raise_error(*args, **kwargs): router.table.connection.update_item.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - router.register_user(dict(uaid="asdf", node_id="me", + router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234)) def test_clear_node_provision_failed(self): @@ -428,7 +430,8 @@ def raise_error(*args, **kwargs): router.table.connection.put_item.side_effect = raise_error with self.assertRaises(ProvisionedThroughputExceededException): - router.clear_node(Item(r, dict(uaid="asdf", connected_at="1234", + router.clear_node(Item(r, dict(uaid=dummy_uaid, + connected_at="1234", node_id="asdf"))) def test_save_uaid(self): @@ -465,7 +468,7 @@ def raise_condition(*args, **kwargs): router.table.connection = Mock() router.table.connection.update_item.side_effect = raise_condition - router_data = dict(uaid="asdf", node_id="asdf", connected_at=1234) + router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) result = router.register_user(router_data) eq_(result, (False, {}, router_data)) @@ -474,18 +477,18 @@ def test_node_clear(self): router = Router(r, SinkMetrics()) # Register a node user - router.register_user(dict(uaid="asdf", node_id="asdf", + router.register_user(dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234)) # Verify - user = router.get_uaid("asdf") + user = router.get_uaid(dummy_uaid) eq_(user["node_id"], "asdf") # Clear router.clear_node(user) # Verify - user = router.get_uaid("asdf") + user = router.get_uaid(dummy_uaid) eq_(user.get("node_id"), None) def test_node_clear_fail(self): @@ -497,7 +500,7 @@ def raise_condition(*args, **kwargs): router.table.connection.put_item = Mock() router.table.connection.put_item.side_effect = raise_condition - data = dict(uaid="asdf", node_id="asdf", connected_at=1234) + data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) result = router.clear_node(Item(r, data)) eq_(result, False) diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index 1dda1c04..20582e5e 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -6,6 +6,8 @@ import base64 import random +from hashlib import sha256 + import ecdsa import twisted.internet.base from cryptography.fernet import Fernet, InvalidToken @@ -37,6 +39,8 @@ from autopush.utils import (generate_hash, decipher_public_key) mock_dynamodb2 = mock_dynamodb2() +dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) +dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) def setUp(): @@ -228,13 +232,14 @@ def test_parse_header(self): result = utils.parse_header('fee; fie="foe"; fum=""foobar="five""";') eq_(result[0], 'fee') eq_(result[1], {'fum': 'foobar="five"', 'fie': 'foe'}) + eq_([], utils.parse_header("")) def test_uaid_lookup_results(self): fresult = dict(router_type="test") frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" self.endpoint.ap_settings.routers["test"] = frouter @@ -251,7 +256,7 @@ def test_uaid_lookup_results_bad_ttl(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["ttl"] = "woops" self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" @@ -270,7 +275,7 @@ def test_webpush_ttl_too_large(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["ttl"] = str(MAX_TTL + 100) self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" @@ -328,12 +333,12 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_webpush_bad_routertype(self): fresult = dict(router_type="fred") - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.body = b"stuff" self.endpoint._uaid_lookup_results(fresult) @@ -347,7 +352,7 @@ def test_webpush_uaid_lookup_no_crypto_headers_with_data(self): fresult = dict(router_type="webpush") frouter = self.settings.routers["webpush"] frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.body = b"stuff" self.endpoint._uaid_lookup_results(fresult) @@ -362,7 +367,7 @@ def test_webpush_uaid_lookup_no_crypto_headers_without_data(self): fresult = dict(router_type="webpush") frouter = self.settings.routers["webpush"] frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.endpoint._uaid_lookup_results(fresult) def handle_finish(value): @@ -377,7 +382,7 @@ def test_other_uaid_lookup_no_crypto_headers(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.endpoint.ap_settings.routers["test"] = frouter self.endpoint._uaid_lookup_results(fresult) @@ -391,7 +396,7 @@ def test_webpush_payload_encoding(self): fresult = dict(router_type="webpush") frouter = self.settings.routers["webpush"] frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.request_mock.headers["encryption"] = "stuff" self.request_mock.headers["content-encoding"] = "aes128" self.request_mock.body = b"\xc3\x28\xa0\xa1" @@ -401,7 +406,7 @@ def handle_finish(value): calls = frouter.route_notification.mock_calls eq_(len(calls), 1) (_, (notification, _), _) = calls[0] - eq_(notification.channel_id, "fred") + eq_(notification.channel_id, dummy_chid) eq_(notification.data, b"wyigoQ==") self.finish_deferred.addCallback(handle_finish) @@ -412,7 +417,7 @@ def test_other_payload_encoding(self): frouter = Mock(spec=Router) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() - self.endpoint.chid = "fred" + self.endpoint.chid = dummy_chid self.endpoint.ap_settings.routers["test"] = frouter self.request_mock.body = b"stuff" @@ -422,7 +427,7 @@ def handle_finish(value): calls = frouter.route_notification.mock_calls eq_(len(calls), 1) (_, (notification, _), _) = calls[0] - eq_(notification.channel_id, "fred") + eq_(notification.channel_id, dummy_chid) eq_(notification.data, b"stuff") self.finish_deferred.addCallback(handle_finish) @@ -444,10 +449,10 @@ def test_init_info(self): self.request_mock.headers["authorization"] = "bearer token fred" d = self.endpoint._init_info() eq_(d["authorization"], "bearer token fred") - self.endpoint.uaid = "faa" - eq_(self.endpoint._client_info["uaid_hash"], hasher("faa")) - self.endpoint.chid = "fie" - eq_(self.endpoint._client_info['channelID'], "fie") + self.endpoint.uaid = dummy_uaid + eq_(self.endpoint._client_info["uaid_hash"], hasher(dummy_uaid)) + self.endpoint.chid = dummy_chid + eq_(self.endpoint._client_info['channelID'], dummy_chid) def test_load_params_arguments(self): args = self.endpoint.request.arguments @@ -510,7 +515,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(413) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred @patch_logger @@ -525,7 +530,7 @@ def handle_finish(value): self._assert_error_response(value) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred def test_put_token_invalid(self): @@ -537,7 +542,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred def test_put_token_wrong(self): @@ -548,7 +553,7 @@ def handle_finish(result): self.status_mock.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put('') + self.endpoint.put(None, '') return self.finish_deferred def _throw_item_not_found(self, item): @@ -569,7 +574,12 @@ def handle_finish(result): self.endpoint.version, self.endpoint.data = 789, None - self.endpoint._token_valid('123:456') + exc = self.assertRaises(ValueError, + self.endpoint._token_valid, + ['invalid']) + eq_(exc.message, "Wrong subscription token components") + + self.endpoint._token_valid(['123', dummy_chid]) return self.finish_deferred def test_put_default_router(self): @@ -582,7 +592,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(200) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_router_with_headers(self): @@ -605,7 +615,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(200) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_router_needs_change(self): @@ -625,7 +635,7 @@ def handle_finish(result): assert(self.router_mock.register_user.called) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_router_needs_update(self): @@ -645,7 +655,7 @@ def handle_finish(result): assert(self.router_mock.register_user.called) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_bogus_headers(self): @@ -669,7 +679,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(400) self.finish_deferred.addBoth(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_invalid_vapid_crypto_header(self): @@ -693,10 +703,10 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred - def test_put_invalid_vapid_auth_header(self): + def test_put_invalid_vapid_crypto_key(self): self.request_mock.headers["encryption"] = "ignored" self.request_mock.headers["content-encoding"] = 'text' self.request_mock.headers["authorization"] = "invalid" @@ -717,7 +727,31 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) + return self.finish_deferred + + def test_put_invalid_vapid_auth_header(self): + self.request_mock.headers["encryption"] = "ignored" + self.request_mock.headers["content-encoding"] = 'text' + self.request_mock.headers["authorization"] = "invalid" + self.request_mock.headers["crypto-key"] = "p256ecdsa=crap" + self.request_mock.body = b' ' + self.fernet_mock.decrypt.return_value = "123:456" + self.router_mock.get_uaid.return_value = dict( + router_type="webpush", + router_data=dict(), + ) + self.wp_router_mock.route_notification.return_value = RouterResponse( + status_code=200, + router_data={}, + ) + + def handle_finish(result): + self.assertTrue(result) + self.endpoint.set_status.assert_called_with(401) + + self.finish_deferred.addCallback(handle_finish) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_put_missing_vapid_crypto_header(self): @@ -740,7 +774,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_headers_in_response(self): @@ -764,7 +798,7 @@ def handle_finish(result): "Location", "Somewhere") self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def _gen_jwt(self, header, payload): @@ -812,7 +846,7 @@ def handle_finish(result, crypto_key, token): self.assertTrue(result) self.finish_deferred.addCallback(handle_finish, crypto_key, token) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_decipher_public_key(self): @@ -865,7 +899,7 @@ def handle_finish(result): self.assertTrue(result) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_bad_vapid_auth(self): @@ -897,7 +931,7 @@ def handle_finish(result): self.assertTrue(result) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_no_sig(self): @@ -932,7 +966,7 @@ def handle_finish(result): self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_util_extract_jwt(self): @@ -970,7 +1004,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_bad_exp(self): @@ -1002,7 +1036,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_auth(self): @@ -1027,7 +1061,7 @@ def handle_finish(result): "Location", "Somewhere") self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_logged_delivered(self): @@ -1058,7 +1092,7 @@ def handle_finish(result): log_patcher.stop() self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_post_webpush_with_logged_stored(self): @@ -1089,7 +1123,7 @@ def handle_finish(result): log_patcher.stop() self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred @patch("twisted.python.log") @@ -1120,7 +1154,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(503) self.finish_deferred.addCallback(handle_finish) - self.endpoint.post(dummy_uaid) + self.endpoint.post(None, dummy_uaid) return self.finish_deferred def test_put_db_error(self): @@ -1132,7 +1166,7 @@ def handle_finish(result): self.endpoint.set_status.assert_called_with(503) self.finish_deferred.addCallback(handle_finish) - self.endpoint.put(dummy_uaid) + self.endpoint.put(None, dummy_uaid) return self.finish_deferred def test_cors(self): @@ -1221,9 +1255,84 @@ def test_padding(self): eq_(utils.fix_padding("ab"), "ab==") eq_(utils.fix_padding("abcd"), "abcd") + def test_parse_endpoint(self): + v0_valid = dummy_uaid + ":" + dummy_chid + uaid_strip = dummy_uaid.replace('-', '') + chid_strip = dummy_chid.replace('-', '') + uaid_dec = uaid_strip.decode('hex') + chid_dec = chid_strip.decode('hex') + v1_valid = uaid_dec + chid_dec + pub_key = uuid.uuid4().hex + v2_valid = sha256(pub_key).digest() + v2_invalid = sha256(uuid.uuid4().hex).digest() + # v0 good + self.fernet_mock.decrypt.return_value = v0_valid + tokens = self.settings.parse_endpoint('/valid') + eq_(tokens, (dummy_uaid, dummy_chid)) + + # v0 bad + self.fernet_mock.decrypt.return_value = v1_valid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + '/invalid') + eq_(exc.message, 'Corrupted push token') + + self.fernet_mock.decrypt.return_value = v1_valid[:30] + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v1') + eq_(exc.message, 'Corrupted push token') + + self.fernet_mock.decrypt.return_value = v1_valid + tokens = self.settings.parse_endpoint('valid', 'v1') + eq_(tokens, (uaid_strip, chid_strip)) + + self.fernet_mock.decrypt.return_value = v1_valid + v2_valid + tokens = self.settings.parse_endpoint('valid', 'v2', pub_key) + eq_(tokens, (uaid_strip, chid_strip)) + + self.fernet_mock.decrypt.return_value = v1_valid + "invalid" + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2', pub_key) + eq_(exc.message, "Corrupted push token") + + self.fernet_mock.decrypt.return_value = v1_valid + v2_valid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2', pub_key[:30]) + eq_(exc.message, "Key mismatch") + + self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2') + eq_(exc.message, "Invalid key data") + + self.fernet_mock.decrypt.return_value = v1_valid + v2_invalid + exc = self.assertRaises(ValueError, + self.settings.parse_endpoint, + 'invalid', 'v2', pub_key) + eq_(exc.message, "Key mismatch") + + def test_make_endpoint(self): + + def echo(val): + return val.encode('hex') + + # make a v1 endpoint: + self.fernet_mock.encrypt = echo + strip_uaid = dummy_uaid.replace('-', '') + strip_chid = dummy_chid.replace('-', '') + dummy_key = "RandomKeyString" + sha = sha256(dummy_key).hexdigest() + ep = self.settings.make_endpoint(dummy_uaid, dummy_chid) + eq_(ep, 'http://localhost/push/v1/' + strip_uaid + strip_chid) + ep = self.settings.make_endpoint(dummy_uaid, dummy_chid, + "RandomKeyString") + eq_(ep, 'http://localhost/push/v2/' + strip_uaid + strip_chid + sha) + -dummy_uaid = "00000000123412341234567812345678" -dummy_chid = "11111111123412341234567812345678" CORS_HEAD = "POST,PUT,DELETE" @@ -1378,9 +1487,9 @@ def handle_finish(value): ok_(call_args is not None) args = call_args[0] call_arg = json.loads(args[0]) - eq_(call_arg["uaid"], dummy_uaid) + eq_(call_arg["uaid"], dummy_uaid.replace('-', '')) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") ok_("secret" in call_arg) self.finish_deferred.addCallback(handle_finish) @@ -1417,9 +1526,9 @@ def handle_finish(value): ok_(call_args is not None) args = call_args[0] call_arg = json.loads(args[0]) - eq_(call_arg["uaid"], dummy_uaid) + eq_(call_arg["uaid"], dummy_uaid.replace('-', '')) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") calls = self.reg.ap_settings.router.register_user.call_args call_args = calls[0][0] eq_(True, has_connected_this_month(call_args)) @@ -1475,7 +1584,7 @@ def handle_finish(value): args = call_args[0] call_arg = json.loads(args[0]) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth @@ -1530,7 +1639,7 @@ def handle_finish(value): args = call_args[0] call_arg = json.loads(args[0]) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth @@ -1555,7 +1664,7 @@ def handle_finish(value): args = call_args[0] call_arg = json.loads(args[0]) eq_(call_arg["channelID"], dummy_chid) - eq_(call_arg["endpoint"], "http://localhost/push/abcd123") + eq_(call_arg["endpoint"], "http://localhost/push/v1/abcd123") self.finish_deferred.addCallback(handle_finish) self.reg.request.headers["Authorization"] = self.auth @@ -1613,23 +1722,24 @@ def test_delete_chid(self): dummy_chid, "1", 10000) - messages.register_channel(dummy_uaid, "test") + chid2 = str(uuid.uuid4()) + messages.register_channel(dummy_uaid, chid2) messages.store_message( dummy_uaid, - "test", + chid2, "2", 10000) self.reg.request.headers["Authorization"] = self.auth - def handle_finish(value): + def handle_finish(value, chid2): ml = messages.fetch_messages(dummy_uaid) cl = messages.all_channels(dummy_uaid) eq_(len(ml), 1) - eq_((True, set(['test'])), cl) + eq_((True, set([chid2])), cl) messages.delete_user(dummy_uaid) - self.finish_deferred.addCallback(handle_finish) - self.reg.delete("simplepush", "test", uaid=dummy_uaid, chid=dummy_chid) + self.finish_deferred.addCallback(handle_finish, chid2) + self.reg.delete("simplepush", "test", dummy_uaid, dummy_chid) return self.finish_deferred def test_delete_bad_chid(self): @@ -1647,11 +1757,12 @@ def handle_finish(value): messages.delete_user(dummy_uaid) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid=dummy_uaid, chid="invalid") + self.reg.delete("test", "test", dummy_uaid, "invalid") return self.finish_deferred def test_delete_uaid(self): messages = self.reg.ap_settings.message + chid2 = str(uuid.uuid4()) messages.store_message( dummy_uaid, dummy_chid, @@ -1659,15 +1770,14 @@ def test_delete_uaid(self): 10000) messages.store_message( dummy_uaid, - "test", + chid2, "2", 10000) self.reg.ap_settings.router.drop_user = Mock() self.reg.ap_settings.router.drop_user.return_value = True - def handle_finish(value): + def handle_finish(value, chid2): ml = messages.fetch_messages(dummy_uaid) - cl = messages.all_channels(dummy_uaid) eq_(len(ml), 0) # Note: Router is mocked, so the UAID is never actually # dropped. Normally, this should messages.all_channels @@ -1675,11 +1785,10 @@ def handle_finish(value): ok_(self.reg.ap_settings.router.drop_user.called) eq_(self.reg.ap_settings.router.drop_user.call_args_list[0][0], (dummy_uaid,)) - eq_((True, set(["test"])), cl) - self.finish_deferred.addCallback(handle_finish) + self.finish_deferred.addCallback(handle_finish, chid2) self.reg.request.headers["Authorization"] = self.auth - self.reg.delete("simplepush", "test", uaid=dummy_uaid) + self.reg.delete("simplepush", "test", dummy_uaid) return self.finish_deferred def test_delete_bad_uaid(self): @@ -1689,7 +1798,7 @@ def handle_finish(value): self.reg.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid="invalid") + self.reg.delete("test", "test", "invalid") return self.finish_deferred def test_delete_orphans(self): @@ -1701,7 +1810,7 @@ def handle_finish(value): self.router_mock.drop_user = Mock() self.router_mock.drop_user.return_value = False self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid=dummy_uaid) + self.reg.delete("test", "test", dummy_uaid) return self.finish_deferred def test_delete_bad_auth(self, *args): @@ -1711,7 +1820,7 @@ def handle_finish(value): self.reg.set_status.assert_called_with(401) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", uaid=dummy_uaid) + self.reg.delete("test", "test", dummy_uaid) return self.finish_deferred def test_delete_bad_router(self): @@ -1721,7 +1830,7 @@ def handle_finish(value): self.reg.set_status.assert_called_with(400) self.finish_deferred.addCallback(handle_finish) - self.reg.delete("invalid", "test", uaid=dummy_uaid) + self.reg.delete("invalid", "test", dummy_uaid) return self.finish_deferred def test_validate_auth(self): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index e8dbf4e8..b6ebbc5f 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -90,7 +90,7 @@ def _get_vapid(key=None, payload=None): key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) vk = key.get_verifying_key() auth = jws.sign(payload, key, algorithm="ES256").strip('=') - crypto_key = base64.urlsafe_b64encode(vk.to_string()).strip('=') + crypto_key = base64.urlsafe_b64encode('\4' + vk.to_string()).strip('=') return {"auth": auth, "crypto-key": crypto_key, "key": key} @@ -146,9 +146,11 @@ def hello(self, uaid=None): eq_(result["status"], 200) return result - def register(self, chid=None): + def register(self, chid=None, key=None): chid = chid or str(uuid.uuid4()) - msg = json.dumps(dict(messageType="register", channelID=chid)) + msg = json.dumps(dict(messageType="register", + channelID=chid, + key=key)) log.debug("Send: %s", msg) self.ws.send(msg) result = json.loads(self.ws.recv()) @@ -186,20 +188,17 @@ def delete_notification(self, channel, message=None, status=204): def send_notification(self, channel=None, version=None, data=None, use_header=True, status=None, ttl=200, - timeout=0.2, vapid=False): + timeout=0.2, vapid=None): if not channel: channel = random.choice(self.channels.keys()) endpoint = self.channels[channel] url = urlparse.urlparse(endpoint) http = None - vapid_info = None if url.scheme == "https": # pragma: nocover http = httplib.HTTPSConnection(url.netloc) else: http = httplib.HTTPConnection(url.netloc) - if vapid: - vapid_info = _get_vapid() if self.use_webpush: headers = {} @@ -214,9 +213,9 @@ def send_notification(self, channel=None, version=None, data=None, }) if vapid: headers.update({ - "Authorization": "Bearer " + vapid_info.get('auth') + "Authorization": "Bearer " + vapid.get('auth') }) - ckey = 'p256ecdsa="' + vapid_info.get('crypto-key') + '"' + ckey = 'p256ecdsa="' + vapid.get('crypto-key') + '"' headers.update({ 'Crypto-Key': headers.get('Crypto-Key') + ';' + ckey }) @@ -362,7 +361,8 @@ def setUp(self): # Endpoint HTTP router site = cyclone.web.Application([ - (r"/push/([^\/]+)", EndpointHandler, dict(ap_settings=settings)), + (r"/push/(v\d+)?/?([^\/]+)", EndpointHandler, + dict(ap_settings=settings)), (r"/m/([^\/]+)", MessageHandler, dict(ap_settings=settings)), # PUT /register/ => connect info # GET /register/uaid => chid + endpoint @@ -693,7 +693,8 @@ def test_basic_delivery(self): def test_basic_delivery_with_vapid(self): data = str(uuid.uuid4()) client = yield self.quick_register(use_webpush=True) - result = yield client.send_notification(data=data, vapid=True) + vapid_info = _get_vapid() + result = yield client.send_notification(data=data, vapid=vapid_info) eq_(result["headers"]["encryption"], client._crypto_key) eq_(result["data"], urlsafe_b64encode(data)) eq_(result["messageType"], "notification") @@ -1250,6 +1251,34 @@ def test_webpush_monthly_rotation_no_channels(self): yield self.shut_down(client) + @inlineCallbacks + def test_with_key(self): + private_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) + claims = {"aud": "http://example.com", + "exp": int(time.time()) + 86400, + "sub": "a@example.com"} + vapid = _get_vapid(private_key, claims) + pk_hex = vapid['crypto-key'] + chid = str(uuid.uuid4()) + client = Client("ws://localhost:9010/", use_webpush=True) + yield client.connect() + yield client.hello() + yield client.register(chid=chid, key=pk_hex) + # check that the client actually registered the key. + + # Send an update with a properly formatted key. + yield client.send_notification(vapid=vapid) + + # now try an invalid key. + new_key = ecdsa.SigningKey.generate(curve=ecdsa.NIST256p) + vapid = _get_vapid(new_key, claims) + + yield client.send_notification( + vapid=vapid, + status=400) + + yield self.shut_down(client) + class TestHealth(IntegrationBase): @inlineCallbacks diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index e8f72e07..6fbac597 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -1,6 +1,7 @@ import json import time import uuid +from hashlib import sha256 import twisted.internet.base from boto.dynamodb2.exceptions import ( @@ -475,7 +476,7 @@ def test_hello_failure(self): router = self.proto.ap_settings.router router.table.connection.update_item = Mock(side_effect=KeyError) - self._send_message(dict(messageType="hello", channelIDs=[])) + self._send_message(dict(messageType="hello", channelIDs=[], stop=1)) def check_result(msg): eq_(msg["status"], 503) @@ -761,7 +762,7 @@ def check_result(result): def test_register(self): self._connect() - self._send_message(dict(messageType="hello", channelIDs=[])) + self._send_message(dict(messageType="hello", channelIDs=[], stop=1)) d = Deferred() d.addCallback(lambda x: True) @@ -799,6 +800,39 @@ def check_register_result(msg): res.addCallback(check_register_result) return d + def test_register_webpush_with_key(self): + self._connect() + self.proto.ps.use_webpush = True + chid = str(uuid.uuid4()) + self.proto.ps.uaid = str(uuid.uuid4()) + self.proto.ap_settings.message.register_channel = Mock() + test_key = "SomeRandomCryptoKeyString" + test_sha = sha256(test_key).hexdigest() + test_endpoint = ('http://localhost/push/v2/' + + self.proto.ps.uaid.replace('-', '') + + chid.replace('-', '') + + test_sha) + self.proto.sendJSON = Mock() + + def echo(str): + return str.encode('hex') + + self.proto.ap_settings.fernet.encrypt = echo + + d = Deferred() + + def check_register_result(msg, test_endpoint): + eq_(test_endpoint, + self.proto.sendJSON.call_args[0][0]['pushEndpoint']) + assert self.proto.ap_settings.message.register_channel.called + d.callback(True) + + res = self.proto.process_register( + dict(channelID=chid, + key=test_key)) + res.addCallback(check_register_result, test_endpoint) + return d + def test_register_no_chid(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) @@ -1098,37 +1132,6 @@ def test_notification_avoid_newer_delivery(self): args = self.send_mock.call_args eq_(args, None) - def test_notification_retains_no_dash(self): - self._connect() - - uaid = str(uuid.uuid4()).replace('-', '') - chid = str(uuid.uuid4()).replace('-', '') - - storage = self.proto.ap_settings.storage - storage.save_notification(uaid, chid, 10) - self._send_message(dict(messageType="hello", channelIDs=[], uaid=uaid)) - - d = Deferred() - - def check_notif_result(msg): - eq_(msg["messageType"], "notification") - updates = msg["updates"] - eq_(len(updates), 1) - eq_(updates[0]["channelID"], chid) - eq_(updates[0]["version"], 10) - d.callback(True) - - def check_hello_result(msg): - eq_(msg["status"], 200) - - # Now wait for the notification - nd = self._check_response(check_notif_result) - nd.addErrback(lambda x: d.errback(x)) - - f = self._check_response(check_hello_result) - f.addErrback(lambda x: d.errback(x)) - return d - def test_ack(self): patcher = patch('autopush.websocket.log', spec=True) mock_log = patcher.start() diff --git a/autopush/websocket.py b/autopush/websocket.py index cccfb6b4..56768fb8 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -1000,7 +1000,7 @@ def process_register(self, data): self.transport.pauseProducing() d = self.deferToThread(self.ap_settings.make_endpoint, self.ps.uaid, - chid) + chid, data.get("key")) d.addCallback(self.finish_register, chid) d.addErrback(self.trap_cancel) d.addErrback(self.error_register)