From a2d4444ecda5c5577a9524bb73177d1ded1edb3a Mon Sep 17 00:00:00 2001 From: Ben Bangert Date: Sun, 25 Sep 2016 19:08:30 -0700 Subject: [PATCH] feat: add webpush topics Add's webpush topics with versioned sort key. Closes #643 --- autopush/db.py | 186 ++++++--------- autopush/endpoint.py | 118 +++++----- autopush/router/apnsrouter.py | 2 +- autopush/router/webpush.py | 52 ++--- autopush/tests/test_db.py | 148 +++--------- autopush/tests/test_endpoint.py | 154 ++++++++----- autopush/tests/test_integration.py | 85 ++++--- autopush/tests/test_router.py | 134 +++++++---- autopush/tests/test_web_validation.py | 41 +++- autopush/tests/test_websocket.py | 43 ++-- autopush/utils.py | 317 ++++++++++++++++++++++++++ autopush/web/base.py | 9 +- autopush/web/simplepush.py | 3 +- autopush/web/validation.py | 41 ++-- autopush/web/webpush.py | 14 +- autopush/websocket.py | 85 ++----- base-requirements.txt | 3 +- docs/http.rst | 1 + 18 files changed, 872 insertions(+), 564 deletions(-) diff --git a/autopush/db.py b/autopush/db.py index 55853787..a2ad8ef7 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -1,4 +1,34 @@ -"""Database Interaction""" +"""Database Interaction + +WebPush Sort Keys +----------------- + +Messages for WebPush are stored using a partition key + sort key, originally +the sort key was: + + CHID : Encrypted(UAID: CHID) + +The encrypted portion was returned as the Location to the Application Server. +Decrypting it resulted in enough information to create the sort key so that +the message could be deleted and located again. + +For WebPush Topic messages, a new scheme was needed since the only way to +locate the prior message is the UAID + CHID + Topic. Using Encryption in +the sort key is therefore not useful since it would change every update. + +The sort key scheme for WebPush messages is: + + VERSION : CHID : TOPIC + +To ensure updated messages are not deleted, each message will still have an +update-id key/value in its item. + +Non-versioned messages are assumed to be original messages from before this +scheme was adopted. + +``VERSION`` is a 2-digit 0-padded number, starting at 01 for Topic messages. + +""" from __future__ import absolute_import import datetime @@ -19,7 +49,11 @@ from boto.dynamodb2.types import NUMBER from autopush.exceptions import AutopushException -from autopush.utils import generate_hash +from autopush.utils import ( + generate_hash, + normalize_id, + WebPushNotification, +) key_hash = "" TRACK_DB_CALLS = False @@ -63,16 +97,6 @@ def dump_uaid(uaid_data): return repr(uaid_data) -def normalize_id(ident): - if (len(ident) == 36 and - ident[8] == ident[13] == ident[18] == ident[23] == '-'): - return ident.lower() - raw = filter(lambda x: x in '0123456789abcdef', ident.lower()) - if len(raw) != 32: - raise ValueError("Invalid UUID") - return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:])) - - def make_rotating_tablename(prefix, delta=0, date=None): """Creates a tablename for table rotation based on a prefix with a given month delta.""" @@ -417,130 +441,62 @@ def save_channels(self, uaid, channels): ), overwrite=True) @track_provisioned - def store_message(self, uaid, channel_id, message_id, ttl, data=None, - headers=None, timestamp=None): - """Stores a message in the message table for the given uaid/channel - with the message id""" + def store_message(self, notification): + """Stores a WebPushNotification in the message table + + :type notification: WebPushNotification + :type timestamp: int + + """ item = dict( - uaid=hasher(uaid), - chidmessageid="%s:%s" % (normalize_id(channel_id), message_id), - data=data, - headers=headers, - ttl=ttl, - timestamp=timestamp or int(time.time()), - updateid=uuid.uuid4().hex + uaid=hasher(notification.uaid.hex), + chidmessageid=notification.sort_key, + data=notification.data, + headers=notification.headers, + ttl=notification.ttl, + timestamp=notification.timestamp, + updateid=notification.update_id ) - if data: - item["headers"] = headers - item["data"] = data self.table.put_item(data=item, overwrite=True) return True @track_provisioned - def update_message(self, uaid, channel_id, message_id, ttl, data=None, - headers=None, timestamp=None): - """Updates a message in the message table for the given uaid/channel - /message_id. + def delete_message(self, notification): + """Deletes a specific message - If the message is not present, False is returned. + :type notification: WebPushNotification """ - conn = self.table.connection - item = dict( - ttl=ttl, - timestamp=timestamp or int(time.time()), - updateid=uuid.uuid4().hex - ) - if data: - item["headers"] = headers - item["data"] = data - try: - chidmessageid = "%s:%s" % (normalize_id(channel_id), message_id) - db_key = self.encode({"uaid": hasher(uaid), - "chidmessageid": chidmessageid}) - expr = ("SET #tl=:ttl, #ts=:timestamp," - " updateid=:updateid") - if data: - expr += ", #dd=:data, headers=:headers" - else: - expr += " REMOVE #dd, headers" - expr_values = self.encode({":%s" % k: v for k, v in item.items()}) - conn.update_item( - self.table.table_name, - db_key, - condition_expression="attribute_exists(updateid)", - update_expression=expr, - expression_attribute_names={"#tl": "ttl", - "#ts": "timestamp", - "#dd": "data"}, - expression_attribute_values=expr_values, - ) - except ConditionalCheckFailedException: - return False - return True - - @track_provisioned - def delete_message(self, uaid, channel_id, message_id, updateid=None): - """Deletes a specific message""" - if updateid: + if notification.update_id: try: self.table.delete_item( - uaid=hasher(uaid), - chidmessageid="%s:%s" % (normalize_id(channel_id), - message_id), - expected={'updateid__eq': updateid}) + uaid=hasher(notification.uaid.hex), + chidmessageid=notification.sort_key, + expected={'updateid__eq': notification.update_id}) except ConditionalCheckFailedException: return False else: self.table.delete_item( - uaid=hasher(uaid), - chidmessageid="%s:%s" % (normalize_id(channel_id), - message_id)) + uaid=hasher(notification.uaid.hex), + chidmessageid=notification.sort_key, + ) return True - def delete_messages(self, uaid, chidmessageids): - with self.table.batch_write() as batch: - for chidmessageid in chidmessageids: - if chidmessageid: - batch.delete_item( - uaid=hasher(uaid), - chidmessageid=chidmessageid - ) - @track_provisioned - def delete_messages_for_channel(self, uaid, channel_id): - """Deletes all messages for a uaid/channel_id""" - results = self.table.query_2( - uaid__eq=hasher(uaid), - chidmessageid__beginswith="%s:" % normalize_id(channel_id), - consistent=True, - attributes=("chidmessageid",), - ) - chidmessageids = [x["chidmessageid"] for x in results] - if chidmessageids: - self.delete_messages(uaid, chidmessageids) - return len(chidmessageids) > 0 + def fetch_messages(self, uaid, limit=10): + """Fetches messages for a uaid - @track_provisioned - def delete_user(self, uaid): - """Deletes all messages and channel info for a given uaid""" - results = self.table.query_2( - uaid__eq=hasher(uaid), - chidmessageid__gte=" ", - consistent=True, - attributes=("chidmessageid",), - ) - chidmessageids = [x["chidmessageid"] for x in results] - if chidmessageids: - self.delete_messages(uaid, chidmessageids) + :type uaid: uuid.UUID + :type limit: int - @track_provisioned - def fetch_messages(self, uaid, limit=10): - """Fetches messages for a uaid""" + """ # Eagerly fetches all results in the result set. - return list(self.table.query_2(uaid__eq=hasher(uaid), - chidmessageid__gt=" ", - consistent=True, limit=limit)) + results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), + chidmessageid__gt=" ", + consistent=True, limit=limit)) + return [ + WebPushNotification.from_message_table(uaid, x) for x in results + ] class Router(object): diff --git a/autopush/endpoint.py b/autopush/endpoint.py index 4758fed5..65cddaf1 100644 --- a/autopush/endpoint.py +++ b/autopush/endpoint.py @@ -27,9 +27,8 @@ import uuid import re -from collections import namedtuple - import cyclone.web +from attr import attrs, attrib from boto.dynamodb2.exceptions import ( ItemNotFound, ProvisionedThroughputExceededException, @@ -55,6 +54,7 @@ validate_uaid, extract_jwt, base64url_encode, + WebPushNotification, ) from autopush.web.base import DEFAULT_ERR_URL from autopush.websocket import ms_time @@ -62,6 +62,7 @@ # Our max TTL is 60 days realistically with table rotation, so we hard-code it MAX_TTL = 60 * 60 * 24 * 60 +VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_]+=*$') VALID_TTL = re.compile(r'^\d+$') AUTH_SCHEMES = ["bearer", "webpush"] PREF_SCHEME = "webpush" @@ -79,9 +80,12 @@ } -class Notification(namedtuple("Notification", - "version data channel_id headers ttl")): +@attrs +class Notification(object): """Parsed notification from the request""" + version = attrib() + data = attrib() + channel_id = attrib() def parse_request_params(request): @@ -388,18 +392,6 @@ class MessageHandler(AutoendpointHandler): cors_methods = "DELETE" cors_response_headers = ("location",) - def _token_valid(self, result, func): - """Handles valid token processing, then dispatches to supplied - function""" - info = result.split(":") - if len(info) != 3: - raise InvalidTokenException("Wrong message token components") - - kind, uaid, chid = info - if kind != 'm': - raise InvalidTokenException("Wrong message token kind") - return func(kind, uaid, chid) - @cyclone.web.asynchronous def delete(self, token): """Drops a pending message. @@ -409,20 +401,21 @@ def delete(self, token): yet, will not be dropped. """ - self.version = self._client_info['version'] = token - d = deferToThread(self.ap_settings.fernet.decrypt, - self.version.encode('utf8')) - d.addCallback(self._token_valid, self._delete_message) + message_id = token.encode('utf8') + self.version = self._client_info['version'] = message_id + d = deferToThread(self._delete_message, message_id) + d.addCallback(self._delete_completed) d.addErrback(self._token_err) + self._db_error_handling(d) d.addErrback(self._response_err) return d - def _delete_message(self, kind, uaid, chid): - d = deferToThread(self.ap_settings.message.delete_message, uaid, - chid, self.version) - d.addCallback(self._delete_completed) - self._db_error_handling(d) - return d + def _delete_message(self, message_id): + notif = WebPushNotification.from_message_id( + message_id, + fernet=self.ap_settings.fernet, + ) + return self.ap_settings.message.delete_message(notif) def _delete_completed(self, response): self.log.info(format="Message Deleted", status_code=204, @@ -438,9 +431,6 @@ class EndpointHandler(AutoendpointHandler): "encryption-key", "content-type", "authorization") cors_response_headers = ("location", "www-authenticate") - # Remove trailing padding characters from complex header items like - # Crypto-Key and Encryption - strip_padding = re.compile('=+(?=[,;]|$)') ############################################################# # Cyclone HTTP Methods @@ -514,6 +504,7 @@ def _uaid_lookup_results(self, uaid_data): # Only simplepush uses version/data out of body/query, GCM/APNS will # use data out of the request body 'WebPush' style. use_simplepush = router_key == "simplepush" + topic = self.request.headers.get("topic") if use_simplepush: self.version, data = parse_request_params(self.request) self._client_info['message_id'] = self.version @@ -548,6 +539,21 @@ def _uaid_lookup_results(self, uaid_data): 401, 110, message="Encryption header missing 'salt' value") return + if topic: + if len(topic) > 32: + self._write_response( + 400, 113, message="Topic must be no greater than 32 " + "characters" + ) + return + + if not VALID_BASE64_URL.match(topic): + self._write_response( + 400, 113, message="Topic must be URL and Filename " + "safe Base64 alphabet" + ) + return + if VALID_TTL.match(self.request.headers.get("ttl", "0")): ttl = int(self.request.headers.get("ttl", "0")) # Cap the TTL to our MAX_TTL @@ -565,44 +571,47 @@ def _uaid_lookup_results(self, uaid_data): return if use_simplepush: - self._route_notification(self.version, uaid_data, data) + notification = Notification(version=self.version, data=data, + channel_id=self.chid) + self._route_notification(False, uaid_data, notification) return # Web Push and bridged messages are encrypted binary blobs. We store # and deliver these messages as Base64-encoded strings. data = base64url_encode(self.request.body) + notification = WebPushNotification(uaid=uuid.UUID(self.uaid), + channel_id=uuid.UUID(self.chid), + data=data, + headers=self.request.headers, + ttl=ttl, topic=topic) + if notification.data: + notification.cleanup_headers() + else: + notification.headers = None + # Generate a message ID, then route the notification. - d = deferToThread(self.ap_settings.fernet.encrypt, ':'.join([ - 'm', self.uaid, self.chid]).encode('utf8')) - d.addCallback(self._route_notification, uaid_data, data, ttl) + d = deferToThread(notification.generate_message_id, + self.ap_settings.fernet) + d.addCallback(self._route_notification, uaid_data, notification) return d - def _route_notification(self, version, uaid_data, data, ttl=None): - self.version = self._client_info['message_id'] = version - warning = "" - # Clean up the header values (remove padding) - for hdr in ['crypto-key', 'encryption']: - if self.strip_padding.search(self.request.headers.get(hdr, "")): - warning = ("Padded content detected. Please strip" - " base64 encoding padding.") - head = self.request.headers[hdr].replace('"', '') - self.request.headers[hdr] = self.strip_padding.sub("", head) - notification = Notification(version=version, data=data, - channel_id=self.chid, - headers=self.request.headers, - ttl=ttl) - + def _route_notification(self, webpush_message_id, uaid_data, notification): + if webpush_message_id: + self.version = self._client_info['message_id'] = webpush_message_id + else: + self.version = self._client_info['message_id'] = \ + notification.version d = Deferred() d.addCallback(self.router.route_notification, uaid_data) - d.addCallback(self._router_completed, uaid_data, warning) + d.addCallback(self._router_completed, uaid_data) d.addErrback(self._router_fail_err) d.addErrback(self._response_err) # Call the prepared router d.callback(notification) - def _router_completed(self, response, uaid_data, warning=""): + def _router_completed(self, response, uaid_data): """Called after router has completed successfully""" # TODO: Add some custom wake logic here # Were we told to update the router data? @@ -629,8 +638,7 @@ def _router_completed(self, response, uaid_data, warning=""): uaid_data) response.router_data = None d.addCallback(lambda x: self._router_completed(response, - uaid_data, - warning)) + uaid_data)) return d else: # No changes are requested by the bridge system, proceed as normal @@ -642,8 +650,7 @@ def _router_completed(self, response, uaid_data, warning=""): client_info=self._client_info) time_diff = time.time() - self.start_time self.metrics.timing("updates.handled", duration=time_diff) - response.response_body = ( - response.response_body + " " + warning).strip() + response.response_body = (response.response_body).strip() self._router_response(response) @@ -753,13 +760,10 @@ def put(self, router_type="", router_token="", uaid="", chid=""): def _delete_channel(self, uaid, chid): message = self.ap_settings.message - message.delete_messages_for_channel(uaid, chid) if not message.unregister_channel(uaid, chid): raise ItemNotFound("ChannelID not found") def _delete_uaid(self, uaid, router): - message = self.ap_settings.message - message.delete_user(uaid) self.log.info(format="Dropping User", code=101, uaid_hash=hasher(uaid)) if not router.drop_user(uaid): diff --git a/autopush/router/apnsrouter.py b/autopush/router/apnsrouter.py index 13a3ee4a..9605beff 100644 --- a/autopush/router/apnsrouter.py +++ b/autopush/router/apnsrouter.py @@ -108,7 +108,7 @@ def _route(self, notification, router_data): config = self._config[rel_channel] apns_client = self.apns[rel_channel] payload = { - "chid": notification.channel_id, + "chid": str(notification.channel_id), "ver": notification.version, } if notification.data: diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 816747a5..676950eb 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -33,7 +33,7 @@ class WebPushRouter(SimpleRouter): def delivered_response(self, notification): location = "%s/m/%s" % (self.ap_settings.endpoint_url, - notification.version) + notification.location) return RouterResponse(status_code=201, response_body="", headers={"Location": location, "TTL": notification.ttl or 0}, @@ -41,31 +41,21 @@ def delivered_response(self, notification): def stored_response(self, notification): location = "%s/m/%s" % (self.ap_settings.endpoint_url, - notification.version) + notification.location) return RouterResponse(status_code=201, response_body="", headers={"Location": location, "TTL": notification.ttl}, logged_status=202) - def _crypto_headers(self, notification): - """Creates a dict of the crypto headers for this request.""" - headers = notification.headers - data = dict( - encoding=headers["content-encoding"], - encryption=headers["encryption"], - ) - # AWS cannot store empty strings, so we only add these keys if - # they're present to avoid empty strings. - for name in ["encryption-key", "crypto-key"]: - if name in headers: - # NOTE: The client code expects all header keys to be lower - # case and s/-/_/. - data[name.lower().replace("-", "_")] = headers[name] - return data - @inlineCallbacks def preflight_check(self, uaid_data, channel_id): - """Verifies this routing call can be done successfully""" + """Verifies this routing call can be done successfully + + :type uaid_data: dict + :type channel_id: uuid.UUID + + """ + channel_id = normalize_id(channel_id.hex) uaid = uaid_data["uaid"] if 'current_month' not in uaid_data: self.log.info(format="Dropping User", code=102, @@ -101,16 +91,11 @@ def _send_notification(self, uaid, node_id, notification): This version of the overriden method includes the necessary crypto headers for the notification. + :type notification: autopush.utils.WebPushNotification + """ - # Firefox currently requires channelIDs to be '-' formatted. - payload = {"channelID": normalize_id(notification.channel_id), - "version": notification.version, - "ttl": notification.ttl or 0, - "timestamp": int(time.time()), - } - if notification.data: - payload["headers"] = self._crypto_headers(notification) - payload["data"] = notification.data + payload = notification.serialize() + payload["timestamp"] = int(time.time()) url = node_id + "/push/" + uaid d = self.ap_settings.agent.request( "PUT", @@ -146,18 +131,9 @@ def _save_notification(self, uaid, notification, month_table): headers={"TTL": str(notification.ttl), "Location": location}, logged_status=204) - headers = None - if notification.data: - headers = self._crypto_headers(notification) return deferToThread( self.ap_settings.message_tables[month_table].store_message, - uaid=uaid, - channel_id=notification.channel_id, - data=notification.data, - headers=headers, - message_id=notification.version, - ttl=notification.ttl, - timestamp=int(time.time()), + notification=notification, ) def amend_msg(self, msg, router_data=None): diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index f9aa1364..eb26fc8a 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -26,6 +26,7 @@ Router, ) from autopush.metrics import SinkMetrics +from autopush.utils import WebPushNotification dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) @@ -42,6 +43,17 @@ def tearDown(): tearDown() +def make_webpush_notification(uaid, chid, ttl=100): + message_id = str(uuid.uuid4()) + return WebPushNotification( + uaid=uuid.UUID(uaid), + channel_id=uuid.UUID(chid), + update_id=message_id, + message_id=message_id, + ttl=ttl, + ) + + class DbCheckTestCase(unittest.TestCase): def test_preflight_check_fail(self): router = Router(get_router_table(), SinkMetrics()) @@ -214,9 +226,6 @@ def setUp(self): def tearDown(self): self.real_table.connection = self.real_connection - def _nstime(self): - return int(time.time() * 1000 * 1000) - def test_register(self): chid = str(uuid.uuid4()) m = get_rotating_message_table() @@ -301,106 +310,37 @@ def test_message_storage(self): message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) - data1 = str(uuid.uuid4()) - data2 = str(uuid.uuid4()) - ttl = int(time.time())+100 - time1, time2, time3 = self._nstime(), self._nstime(), self._nstime()+1 - message.store_message(self.uaid, chid, time1, ttl, data1, {}) - message.store_message(self.uaid, chid2, time2, ttl, data2, {}) - message.store_message(self.uaid, chid2, time3, ttl, data1, {}) + message.store_message(make_webpush_notification(self.uaid, chid)) + message.store_message(make_webpush_notification(self.uaid, chid)) + message.store_message(make_webpush_notification(self.uaid, chid)) - all_messages = list(message.fetch_messages(self.uaid)) + all_messages = list(message.fetch_messages(uuid.UUID(self.uaid))) eq_(len(all_messages), 3) - message.delete_messages_for_channel(self.uaid, chid2) - all_messages = list(message.fetch_messages(self.uaid)) - eq_(len(all_messages), 1) - - message.delete_message(self.uaid, chid, time1) - all_messages = list(message.fetch_messages(self.uaid)) - eq_(len(all_messages), 0) - def test_message_storage_overwrite(self): """Test that store_message can overwrite existing messages which can occur in some reconnect cases but shouldn't error""" chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) + notif1 = make_webpush_notification(self.uaid, chid) + notif2 = make_webpush_notification(self.uaid, chid) + notif3 = make_webpush_notification(self.uaid, chid2) + notif2.message_id = notif1.message_id m = get_rotating_message_table() message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) - data1 = str(uuid.uuid4()) - data2 = str(uuid.uuid4()) - ttl = int(time.time())+100 - time1, time2 = self._nstime(), self._nstime()+1 - message.store_message(self.uaid, chid, time1, ttl, data1, {}) - message.store_message(self.uaid, chid, time1, ttl, data2, {}) - message.store_message(self.uaid, chid2, time2, ttl, data1, {}) + message.store_message(notif1) + message.store_message(notif2) + message.store_message(notif3) - all_messages = list(message.fetch_messages(self.uaid)) + all_messages = list(message.fetch_messages(uuid.UUID(self.uaid))) eq_(len(all_messages), 2) - message.delete_messages_for_channel(self.uaid, chid2) - all_messages = list(message.fetch_messages(self.uaid)) - eq_(len(all_messages), 1) - - message.delete_message(self.uaid, chid, time1) - all_messages = list(message.fetch_messages(self.uaid)) - eq_(len(all_messages), 0) - - def test_delete_user(self): - chid = str(uuid.uuid4()) - chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) - message.register_channel(self.uaid, chid2) - - data1 = str(uuid.uuid4()) - data2 = str(uuid.uuid4()) - ttl = int(time.time())+100 - time1, time2, time3 = self._nstime(), self._nstime(), self._nstime()+1 - message.store_message(self.uaid, chid, time1, ttl, data1, {}) - message.store_message(self.uaid, chid2, time2, ttl, data2, {}) - message.store_message(self.uaid, chid2, time3, ttl, data1, {}) - - message.delete_user(self.uaid) - all_messages = list(message.fetch_messages(self.uaid)) - eq_(len(all_messages), 0) - - def test_message_delete_pagination(self): - def make_messages(channel_id, count, rtable): - t = self._nstime() - ttl = int(time.time())+200 - for i in range(count): - rtable.append( - (self.uaid, channel_id, str(uuid.uuid4()), ttl, {}, t+i) - ) - return rtable - - chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) - - # Shove 80 messages in - m = [] - for message_args in make_messages(chid, 80, m): - message.store_message(*message_args) - - # Verify we can see them all - all_messages = list(message.fetch_messages(self.uaid, limit=100)) - eq_(len(all_messages), 80) - - # Delete them all - message.delete_messages_for_channel(self.uaid, chid) - - # Verify they're gone - all_messages = list(message.fetch_messages(self.uaid, limit=100)) - eq_(len(all_messages), 0) - def test_message_delete_fail_condition(self): + notif = make_webpush_notification(dummy_uaid, dummy_chid) + notif.message_id = notif.update_id = dummy_uaid m = get_rotating_message_table() message = Message(m, SinkMetrics()) @@ -409,43 +349,9 @@ def raise_condition(*args, **kwargs): message.table = Mock() message.table.delete_item.side_effect = raise_condition - result = message.delete_message(uaid=dummy_uaid, channel_id=dummy_chid, - message_id="asdf", updateid="asdf") + result = message.delete_message(notif) eq_(result, False) - def test_update_message(self): - chid = uuid.uuid4().hex - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - data1 = str(uuid.uuid4()) - data2 = str(uuid.uuid4()) - time1 = self._nstime() - ttl = self._nstime()+1000 - message.store_message(self.uaid, chid, time1, ttl, data1, {}) - message.update_message(self.uaid, chid, time1, ttl, data2, {}) - messages = list(message.fetch_messages(self.uaid)) - eq_(data2, messages[0]['data']) - - def test_update_message_fail(self): - message = Message(get_rotating_message_table(), SinkMetrics) - message.store_message(self.uaid, - uuid.uuid4().hex, - self._nstime(), - str(uuid.uuid4()), - {}) - u = message.table.connection.update_item = Mock() - - def raise_condition(*args, **kwargs): - raise ConditionalCheckFailedException(None, None) - - u.side_effect = raise_condition - b = message.update_message(self.uaid, - uuid.uuid4().hex, - self._nstime(), - str(uuid.uuid4()), - {}) - eq_(b, False) - class RouterTestCase(unittest.TestCase): def setUp(self): diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index d6f294c9..aa8d13f2 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -37,6 +37,7 @@ from autopush.exceptions import InvalidTokenException, VapidAuthException from autopush.settings import AutopushSettings from autopush.router.interface import IRouter, RouterResponse +from autopush.tests.test_db import make_webpush_notification from autopush.utils import ( generate_hash, decipher_public_key, @@ -114,7 +115,8 @@ def handle_finish(result): return self.finish_deferred def test_delete_token_wrong_kind(self): - self.fernet_mock.decrypt.return_value = "r:123:456" + tok = ":".join(["r", dummy_uaid, dummy_chid]) + self.fernet_mock.decrypt.return_value = tok def handle_finish(result): self.status_mock.assert_called_with(400, None) @@ -124,21 +126,49 @@ def handle_finish(result): return self.finish_deferred def test_delete_success(self): - self.fernet_mock.decrypt.return_value = "m:123:456" + tok = ":".join(["m", dummy_uaid, dummy_chid]) + self.fernet_mock.decrypt.return_value = tok self.message_mock.configure_mock(**{ "delete_message.return_value": True}) def handle_finish(result): - self.message_mock.delete_message.assert_called_with( - "123", "456", "123-456") + self.message_mock.delete_message.assert_called() self.status_mock.assert_called_with(204) self.finish_deferred.addCallback(handle_finish) self.message.delete("123-456") return self.finish_deferred + def test_delete_topic_success(self): + tok = ":".join(["01", dummy_uaid, dummy_chid, "Inbox"]) + self.fernet_mock.decrypt.return_value = tok + self.message_mock.configure_mock(**{ + "delete_message.return_value": True}) + + def handle_finish(result): + self.message_mock.delete_message.assert_called() + self.status_mock.assert_called_with(204) + self.finish_deferred.addCallback(handle_finish) + + self.message.delete("123-456") + return self.finish_deferred + + def test_delete_topic_error_parts(self): + tok = ":".join(["01", dummy_uaid, dummy_chid]) + self.fernet_mock.decrypt.return_value = tok + self.message_mock.configure_mock(**{ + "delete_message.return_value": True}) + + def handle_finish(result): + self.status_mock.assert_called_with(400, None) + self.finish_deferred.addCallback(handle_finish) + + self.message.delete("123-456") + return self.finish_deferred + def test_delete_db_error(self): - self.fernet_mock.decrypt.return_value = "m:123:456" + tok = ":".join(["m", dummy_uaid, dummy_chid]) + self.fernet_mock.decrypt.return_value = tok self.message_mock.configure_mock(**{ "delete_message.side_effect": ProvisionedThroughputExceededException(None, None)}) @@ -203,6 +233,7 @@ def setUp(self): self.status_mock = self.endpoint.set_status = Mock() self.write_mock = self.endpoint.write = Mock() self.endpoint.log = Mock(spec=Logger) + self.endpoint.uaid = dummy_uaid d = self.finish_deferred = Deferred() self.endpoint.finish = lambda: d.callback(True) @@ -279,6 +310,7 @@ def test_webpush_missing_ttl(self): frouter = Mock(spec=IRouter) frouter.route_notification = Mock() frouter.route_notification.return_value = RouterResponse() + self.endpoint.chid = dummy_chid self.endpoint.ap_settings.routers["webpush"] = frouter self.endpoint._uaid_lookup_results(dict(router_type="webpush")) @@ -406,13 +438,11 @@ def handle_finish(value): _, (notification, _), _ = calls[0] eq_(notification.headers.get('encryption'), 'keyid=p256;salt=stuff') - eq_(notification.headers.get('crypto-key'), + eq_(notification.headers.get('crypto_key'), 'keyid=spad;dh=AQ,p256ecdsa=Ag;foo=bar') - eq_(notification.channel_id, dummy_chid) + eq_(str(notification.channel_id), dummy_chid) eq_(notification.data, b"wyigoQ") self.endpoint.set_status.assert_called_with(200) - ok_('Padded content detected' in - self.endpoint.write.call_args[0][0]) self.finish_deferred.addCallback(handle_finish) return self.finish_deferred @@ -457,6 +487,44 @@ def handle_finish(value): self.finish_deferred.addCallback(handle_finish) return self.finish_deferred + def test_webpush_bad_topic_len(self): + fresult = dict(router_type="webpush") + frouter = self.settings.routers["webpush"] + frouter.route_notification.return_value = RouterResponse() + self.endpoint.chid = dummy_chid + self.request_mock.headers["topic"] = \ + "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdfasdf" + self.request_mock.body = b"" + self.endpoint._uaid_lookup_results(fresult) + + def handle_finish(value): + self.endpoint.set_status.assert_called_with(400, None) + self._check_error(code=400, errno=113, + message="Topic must be no greater than 32 " + "characters") + + self.finish_deferred.addCallback(handle_finish) + return self.finish_deferred + + def test_webpush_bad_topic_content(self): + fresult = dict(router_type="webpush") + frouter = self.settings.routers["webpush"] + frouter.route_notification.return_value = RouterResponse() + self.endpoint.chid = dummy_chid + self.request_mock.headers["topic"] = \ + "asdf:442;23^@*#$(O!4232" + self.request_mock.body = b"" + self.endpoint._uaid_lookup_results(fresult) + + def handle_finish(value): + self.endpoint.set_status.assert_called_with(400, None) + self._check_error(code=400, errno=113, + message="Topic must be URL and Filename " + "safe Base64 alphabet") + + self.finish_deferred.addCallback(handle_finish) + return self.finish_deferred + @patch('uuid.uuid4', return_value=uuid.UUID(dummy_request_id)) def test_init_info(self, t): d = self.endpoint._init_info() @@ -1917,74 +1985,56 @@ def handle_finish(value): self.reg.put(uaid=dummy_uaid) return self.finish_deferred - def test_delete_chid(self): + def test_delete_bad_chid_value(self): + notif = make_webpush_notification(dummy_uaid, dummy_chid) messages = self.reg.ap_settings.message messages.register_channel(dummy_uaid, dummy_chid) - messages.store_message( - dummy_uaid, - dummy_chid, - "1", - 10000) - chid2 = str(uuid.uuid4()) - messages.register_channel(dummy_uaid, chid2) - messages.store_message( - dummy_uaid, - chid2, - "2", - 10000) + messages.store_message(notif) self.reg.request.headers["Authorization"] = self.auth - def handle_finish(value, chid2): - ml = messages.fetch_messages(dummy_uaid) - cl = messages.all_channels(dummy_uaid) - eq_(len(ml), 1) - eq_((True, set([chid2])), cl) - messages.delete_user(dummy_uaid) + def handle_finish(value): + self._check_error(410, 106, "") - self.finish_deferred.addCallback(handle_finish, chid2) - self.reg.delete("simplepush", "test", dummy_uaid, dummy_chid) + self.finish_deferred.addCallback(handle_finish) + self.reg.delete("test", "test", dummy_uaid, "invalid") return self.finish_deferred - def test_delete_bad_chid(self): + def test_delete_no_such_chid(self): + notif = make_webpush_notification(dummy_uaid, dummy_chid) messages = self.reg.ap_settings.message messages.register_channel(dummy_uaid, dummy_chid) - messages.store_message( - dummy_uaid, - dummy_chid, - "1", - 10000) + messages.store_message(notif) + + # Moto can't handle set operations of this nature so we have + # to mock the reply + unreg = messages.unregister_channel + messages.unregister_channel = Mock(return_value=False) self.reg.request.headers["Authorization"] = self.auth def handle_finish(value): self._check_error(410, 106, "") - messages.delete_user(dummy_uaid) + + def fixup_messages(result): + messages.unregister_channel = unreg self.finish_deferred.addCallback(handle_finish) - self.reg.delete("test", "test", dummy_uaid, "invalid") + self.finish_deferred.addBoth(fixup_messages) + self.reg.delete("test", "test", dummy_uaid, str(uuid.uuid4())) return self.finish_deferred def test_delete_uaid(self): + notif = make_webpush_notification(dummy_uaid, dummy_chid) + notif2 = make_webpush_notification(dummy_uaid, dummy_chid) messages = self.reg.ap_settings.message chid2 = str(uuid.uuid4()) - messages.store_message( - dummy_uaid, - dummy_chid, - "1", - 10000) - messages.store_message( - dummy_uaid, - chid2, - "2", - 10000) + messages.store_message(notif) + messages.store_message(notif2) self.reg.ap_settings.router.drop_user = Mock() self.reg.ap_settings.router.drop_user.return_value = True def handle_finish(value, chid2): - ml = messages.fetch_messages(dummy_uaid) - eq_(len(ml), 0) # Note: Router is mocked, so the UAID is never actually - # dropped. Normally, this should messages.all_channels - # would come back as empty + # dropped. ok_(self.reg.ap_settings.router.drop_user.called) eq_(self.reg.ap_settings.router.drop_user.call_args_list[0][0], (dummy_uaid,)) diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index d1fc031f..59137bb8 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -187,7 +187,8 @@ def delete_notification(self, channel, message=None, status=204): def send_notification(self, channel=None, version=None, data=None, use_header=True, status=None, ttl=200, - timeout=0.2, vapid=None, endpoint=None): + timeout=0.2, vapid=None, endpoint=None, + topic=None): if not channel: channel = random.choice(self.channels.keys()) @@ -218,6 +219,8 @@ def send_notification(self, channel=None, version=None, data=None, headers.update({ 'Crypto-Key': headers.get('Crypto-Key') + ';' + ckey }) + if topic: + headers["Topic"] = topic body = data or "" method = "POST" status = status or 201 @@ -716,6 +719,57 @@ def test_basic_delivery(self): eq_(result["messageType"], "notification") yield self.shut_down(client) + @inlineCallbacks + def test_topic_basic_delivery(self): + data = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + result = yield client.send_notification(data=data, topic="Inbox") + eq_(result["headers"]["encryption"], client._crypto_key) + eq_(result["data"], base64url_encode(data)) + eq_(result["messageType"], "notification") + yield self.shut_down(client) + + @inlineCallbacks + def test_topic_replacement_delivery(self): + data = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + yield client.send_notification(data=data, topic="Inbox") + yield client.send_notification(data=data2, topic="Inbox") + yield client.connect() + yield client.hello() + result = yield client.get_notification() + eq_(result["headers"]["encryption"], client._crypto_key) + eq_(result["data"], base64url_encode(data2)) + eq_(result["messageType"], "notification") + result = yield client.get_notification() + eq_(result, None) + yield self.shut_down(client) + + @inlineCallbacks + def test_topic_no_delivery_on_reconnect(self): + data = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + yield client.send_notification(data=data, topic="Inbox") + yield client.connect() + yield client.hello() + result = yield client.get_notification(timeout=10) + eq_(result["headers"]["encryption"], client._crypto_key) + eq_(result["data"], base64url_encode(data)) + eq_(result["messageType"], "notification") + yield client.ack(result["channelID"], result["version"]) + yield client.disconnect() + yield client.connect() + yield client.hello() + result = yield client.get_notification() + eq_(result, None) + yield client.disconnect() + yield client.connect() + yield client.hello() + yield self.shut_down(client) + @inlineCallbacks def test_basic_delivery_v0_endpoint(self): data = str(uuid.uuid4()) @@ -863,29 +917,6 @@ def test_no_delivery_to_unregistered(self): eq_(result, None) yield self.shut_down(client) - @inlineCallbacks - def test_no_delivery_to_unregistered_on_reconnect(self): - data = str(uuid.uuid4()) - client = yield self.quick_register(use_webpush=True) - yield client.disconnect() - ok_(client.channels) - chan = client.channels.keys()[0] - yield client.send_notification(data=data) - yield client.connect() - yield client.hello() - result = yield client.get_notification() - eq_(result["channelID"], chan) - eq_(result["data"], base64url_encode(data)) - - yield client.unregister(chan) - yield client.disconnect() - time.sleep(1) - yield client.connect() - yield client.hello() - result = yield client.get_notification() - eq_(result, None) - yield self.shut_down(client) - @inlineCallbacks def test_ttl_not_present_not_connected(self): data = str(uuid.uuid4()) @@ -1096,7 +1127,8 @@ def test_webpush_monthly_rotation(self): # table data = uuid.uuid4().hex yield client.send_notification(data=data) - notifs = yield deferToThread(lm_message.fetch_messages, client.uaid) + notifs = yield deferToThread(lm_message.fetch_messages, + uuid.UUID(client.uaid)) eq_(len(notifs), 1) # Connect the client, verify the migration @@ -1188,7 +1220,8 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # table data = uuid.uuid4().hex yield client.send_notification(data=data) - notifs = yield deferToThread(lm_message.fetch_messages, client.uaid) + notifs = yield deferToThread(lm_message.fetch_messages, + uuid.UUID(client.uaid)) eq_(len(notifs), 1) # Connect the client, verify the migration diff --git a/autopush/tests/test_router.py b/autopush/tests/test_router.py index be062517..0f1e9e7a 100644 --- a/autopush/tests/test_router.py +++ b/autopush/tests/test_router.py @@ -4,6 +4,7 @@ import time import json +from autopush.utils import WebPushNotification from mock import Mock, PropertyMock, patch from moto import mock_dynamodb2 from nose.tools import eq_, ok_, assert_raises @@ -127,8 +128,14 @@ def setUp(self, mt, mc): self.headers = {"content-encoding": "aesgcm", "encryption": "test", "encryption-key": "test"} - self.notif = Notification(10, "q60d6g", dummy_chid, self.headers, - 200) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=self.headers, + ttl=200, + message_id=10, + ) self.router_data = dict(router_data=dict(token="connect_data", rel_channel="firefox")) @@ -238,7 +245,14 @@ def test_route_crypto_key(self): headers = {"content-encoding": "aesgcm", "encryption": "test", "crypto-key": "test"} - self.notif = Notification(10, "q60d6g", dummy_chid, headers, 200) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=headers, + ttl=200, + message_id=10, + ) d = self.router.route_notification(self.notif, self.router_data) def check_results(result): @@ -271,8 +285,14 @@ def setUp(self, fgcm): "encryption": "test", "encryption-key": "test"} # Payloads are Base64-encoded. - self.notif = Notification(10, "q60d6g", dummy_chid, self.headers, - 200) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=self.headers, + ttl=200, + message_id=10, + ) self.router_data = dict( router_data=dict( token="connect_data", @@ -349,11 +369,13 @@ def check_results(result): def test_ttl_none(self): self.router.gcm['test123'] = self.gcm - self.notif = Notification(version=10, - data="q60d6g", - channel_id=dummy_chid, - headers=self.headers, - ttl=None) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=self.headers, + ttl=None + ) d = self.router.route_notification(self.notif, self.router_data) def check_results(result): @@ -376,11 +398,13 @@ def check_results(result): def test_ttl_high(self): self.router.gcm['test123'] = self.gcm - self.notif = Notification(version=10, - data="q60d6g", - channel_id=dummy_chid, - headers=self.headers, - ttl=5184000) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=self.headers, + ttl=5184000 + ) d = self.router.route_notification(self.notif, self.router_data) def check_results(result): @@ -400,9 +424,13 @@ def check_results(result): def test_long_data(self): self.router.gcm['test123'] = self.gcm - bad_notif = Notification( - 10, "\x01abcdefghijklmnopqrstuvwxyz0123456789", dummy_chid, - self.headers, 200) + bad_notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="\x01abcdefghijklmnopqrstuvwxyz0123456789", + headers=self.headers, + ttl=200 + ) d = self.router.route_notification(bad_notif, self.router_data) def check_results(result): @@ -537,8 +565,13 @@ def setUp(self, ffcm): "encryption": "test", "encryption-key": "test"} # Payloads are Base64-encoded. - self.notif = Notification(10, "q60d6g", dummy_chid, self.headers, - 200) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=self.headers, + ttl=200 + ) self.router_data = dict( router_data=dict( token="connect_data", @@ -607,11 +640,13 @@ def check_results(result): def test_ttl_none(self): self.router.fcm = self.fcm - self.notif = Notification(version=10, - data="q60d6g", - channel_id=dummy_chid, - headers=self.headers, - ttl=None) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=self.headers, + ttl=None + ) d = self.router.route_notification(self.notif, self.router_data) def check_results(result): @@ -631,11 +666,13 @@ def check_results(result): def test_ttl_high(self): self.router.fcm = self.fcm - self.notif = Notification(version=10, - data="q60d6g", - channel_id=dummy_chid, - headers=self.headers, - ttl=5184000) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="q60d6g", + headers=self.headers, + ttl=5184000 + ) d = self.router.route_notification(self.notif, self.router_data) def check_results(result): @@ -655,9 +692,14 @@ def check_results(result): def test_long_data(self): self.router.fcm = self.fcm - bad_notif = Notification( - 10, "\x01abcdefghijklmnopqrstuvwxyz0123456789", dummy_chid, - self.headers, 200) + bad_notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="\x01abcdefghijklmnopqrstuvwxyz0123456789", + headers=self.headers, + ttl=200, + message_id=10, + ) d = self.router.route_notification(bad_notif, self.router_data) def check_results(result): @@ -777,7 +819,7 @@ def setUp(self): self.router = SimpleRouter(settings, {}) self.router.log = Mock(spec=Logger) - self.notif = Notification(10, "data", dummy_chid, None, 200) + self.notif = Notification(10, "data", dummy_chid) mock_result = Mock(spec=gcmclient.gcm.Result) mock_result.canonical = dict() mock_result.failed = dict() @@ -1019,8 +1061,15 @@ def setUp(self): "crypto-key": "niftykey" } self.router = WebPushRouter(settings, {}) - self.notif = Notification("EncMessageId", "data", - dummy_chid, headers, 20) + self.notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="data", + headers=headers, + ttl=20, + message_id=uuid.uuid4().hex, + ) + self.notif.cleanup_headers() mock_result = Mock(spec=gcmclient.gcm.Result) mock_result.canonical = dict() mock_result.failed = dict() @@ -1050,7 +1099,8 @@ def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): def verify_deliver(result): ok_(isinstance(result, RouterResponse)) eq_(result.status_code, 201) - t_h = self.message_mock.store_message.call_args[1].get('headers') + kwargs = self.message_mock.store_message.call_args[1] + t_h = kwargs["notification"].headers eq_(t_h.get('encryption'), self.headers.get('encryption')) eq_(t_h.get('crypto_key'), self.headers.get('crypto-key')) eq_(t_h.get('encoding'), self.headers.get('content-encoding')) @@ -1063,8 +1113,14 @@ def verify_deliver(result): return d def test_route_to_busy_node_with_ttl_zero(self): - notif = Notification("EncMessageId", "data", dummy_chid, - self.headers, 0) + notif = WebPushNotification( + uaid=uuid.UUID(dummy_uaid), + channel_id=uuid.UUID(dummy_chid), + data="data", + headers=self.headers, + ttl=0, + message_id=uuid.uuid4().hex, + ) self.agent_mock.request.return_value = response_mock = Mock() response_mock.addCallback.return_value = response_mock type(response_mock).code = PropertyMock( diff --git a/autopush/tests/test_web_validation.py b/autopush/tests/test_web_validation.py index 4e1e11bc..cab1becf 100644 --- a/autopush/tests/test_web_validation.py +++ b/autopush/tests/test_web_validation.py @@ -316,7 +316,7 @@ def test_valid_data(self): ) result, errors = schema.load(self._make_test_data()) eq_(errors, {}) - ok_("message_id" in result) + ok_("notification" in result) eq_(str(result["subscription"]["uaid"]), dummy_uaid) def test_no_headers(self): @@ -528,6 +528,45 @@ def test_invalid_vapid_crypto_header(self): eq_(cm.exception.status_code, 401) + def test_invalid_topic(self): + schema = self._make_fut() + schema.context["settings"].parse_endpoint.return_value = dict( + uaid=dummy_uaid, + chid=dummy_chid, + public_key="", + ) + schema.context["settings"].router.get_uaid.return_value = dict( + router_type="webpush", + ) + + info = self._make_test_data( + headers={ + "topic": "asdfasdfasdfasdfasdfasdfasdfasdfasdfasdf", + } + ) + + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 113) + eq_(cm.exception.message, + "Topic must be no greater than 32 characters") + + info = self._make_test_data( + headers={ + "topic": "asdf??asdf::;f", + } + ) + + with assert_raises(InvalidRequest) as cm: + schema.load(info) + + eq_(cm.exception.status_code, 400) + eq_(cm.exception.errno, 113) + eq_(cm.exception.message, + "Topic must be URL and Filename safe Base64 alphabet") + class TestWebPushRequestSchemaUsingVapid(unittest.TestCase): def _make_fut(self): diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index a1c3c5a2..eca19a28 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -5,6 +5,7 @@ from hashlib import sha256 import twisted.internet.base +from autopush.tests.test_db import make_webpush_notification from boto.dynamodb2.exceptions import ( ProvisionedThroughputExceededException, ) @@ -1365,7 +1366,7 @@ def test_notification_with_webpush(self): # Check the call result args = json.loads(self.send_mock.call_args[0][0]) eq_(args, {"messageType": "notification", "channelID": chid, - "data": "bleh", "version": "10:", "headers": {}}) + "data": "bleh", "version": "10", "headers": {}}) def test_notification_avoid_newer_delivery(self): self._connect() @@ -1422,11 +1423,10 @@ def test_ack_with_webpush_direct(self): self.proto.ps.uaid = uuid.uuid4().hex chid = str(uuid.uuid4()) + notif = make_webpush_notification(self.proto.ps.uaid, chid) + notif.message_id = "bleh:asdjfilajsdilfj" self.proto.ps.use_webpush = True - self.proto.ps.direct_updates[chid] = [ - Notification(version="bleh", headers={}, data="meh", - channel_id=chid, ttl=200, timestamp=0) - ] + self.proto.ps.direct_updates[chid] = [notif] self.proto.ack_update(dict( channelID=chid, @@ -1442,12 +1442,12 @@ def test_ack_with_webpush_direct(self): def test_ack_with_webpush_from_storage(self): self._connect() chid = str(uuid.uuid4()) + self.proto.ps.uaid = uuid.uuid4().hex self.proto.ps.use_webpush = True self.proto.ps.direct_updates[chid] = [] - self.proto.ps.updates_sent[chid] = [ - Notification(version="bleh", headers={}, data="meh", - channel_id=chid, ttl=200, timestamp=0) - ] + notif = make_webpush_notification(self.proto.ps.uaid, chid) + notif.message_id = "bleh:jialsdjfilasjdf" + self.proto.ps.updates_sent[chid] = [notif] mock_defer = Mock() self.proto.force_retry = Mock(return_value=mock_defer) @@ -1708,12 +1708,14 @@ def test_notif_finished_with_webpush_with_notifications(self): self.proto.ps.use_webpush = True self.proto.ps._check_notifications = True self.proto.process_notifications = Mock() - self.proto.ps.updates_sent["asdf"] = [] - self.proto.finish_webpush_notifications([ - dict(chidmessageid="asdf:fdsa", headers={}, data="bleh", ttl=100, - timestamp=int(time.time()), updateid=uuid.uuid4().hex) - ]) + notif = make_webpush_notification( + self.proto.ps.uaid, + uuid.uuid4().hex, + ) + self.proto.ps.updates_sent[str(notif.channel_id)] = [] + + self.proto.finish_webpush_notifications([notif]) assert self.send_mock.called def test_notif_finished_with_webpush_with_old_notifications(self): @@ -1722,13 +1724,16 @@ def test_notif_finished_with_webpush_with_old_notifications(self): self.proto.ps.use_webpush = True self.proto.ps._check_notifications = True self.proto.process_notifications = Mock() - self.proto.ps.updates_sent["asdf"] = [] + notif = make_webpush_notification( + self.proto.ps.uaid, + uuid.uuid4().hex, + ttl=5 + ) + notif.timestamp = 0 + self.proto.ps.updates_sent[str(notif.channel_id)] = [] self.proto.force_retry = Mock() - self.proto.finish_webpush_notifications([ - dict(chidmessageid="asdf:fdsa", headers={}, data="bleh", ttl=10, - timestamp=0, updateid=uuid.uuid4().hex) - ]) + self.proto.finish_webpush_notifications([notif]) assert self.proto.force_retry.called assert not self.send_mock.called diff --git a/autopush/utils.py b/autopush/utils.py index ab8f6035..a33ba67f 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -2,14 +2,29 @@ import base64 import hashlib import hmac +import re import socket import uuid import ecdsa import requests +import time + +from attr import ( + Factory, + attrs, + attrib +) from jose import jwt from ua_parser import user_agent_parser +from autopush.exceptions import InvalidTokenException + + +# Remove trailing padding characters from complex header items like +# Crypto-Key and Encryption +STRIP_PADDING = re.compile('=+(?=[,;]|$)') + # List of valid user-agent attributes to keep, anything not in this list is # considered 'Other'. We log the user-agent on connect always to retain the @@ -28,6 +43,16 @@ } +def normalize_id(ident): + if (len(ident) == 36 and + ident[8] == ident[13] == ident[18] == ident[23] == '-'): + return ident.lower() + raw = filter(lambda x: x in '0123456789abcdef', ident.lower()) + if len(raw) != 32: + raise ValueError("Invalid UUID") + return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:])) + + def canonical_url(scheme, hostname, port=None): """Return a canonical URL given a scheme/hostname and optional port""" if port is None or port == default_ports.get(scheme): @@ -198,3 +223,295 @@ def parse_user_agent(agent_string): raw_info["ua_browser_ver"] = ".".join(filter(None, browser_bits)) return dd_info, raw_info + + +@attrs(slots=True) +class WebPushNotification(object): + """WebPush Notification + + This object centralizes all logic involving the addressing of a single + WebPush Notification. + + message_id serves a complex purpose. It's returned as the Location header + value so that an application server may delete the message. It's used as + part of the non-versioned sort-key. Due to this, its an encrypted value + that contains the necessary information to derive the location of this + precise message in the appropriate message table. + + """ + uaid = attrib() # type: uuid.UUID + channel_id = attrib() # type: uuid.UUID + ttl = attrib() # type: int + data = attrib(default=None) + headers = attrib(default=None) # type: dict + timestamp = attrib(default=Factory(lambda: int(time.time()))) + topic = attrib(default=None) + + message_id = attrib(default=None) # type: str + + # Not an alias for message_id, for backwards compat and cases where an old + # message with any update_id should be removed. + update_id = attrib(default=None) # type: str + + def generate_message_id(self, fernet): + """Generate a message-id suitable for accessing the message + + For non-topic messages, no sort_key version is currently used and the + message-id is: + + Encrypted(m : uaid.hex : channel_id.hex) + + For topic messages, a sort_key version of 01 is used, and the topic + is included for reference: + + Encrypted(01 : uaid.hex : channel_id.hex : topic) + + This is a blocking call. + + :type fernet: cryptography.fernet.Fernet + + """ + if self.topic: + msg_key = ":".join(["01", self.uaid.hex, self.channel_id.hex, + self.topic]) + else: + msg_key = ":".join(["m", self.uaid.hex, self.channel_id.hex]) + self.message_id = fernet.encrypt(msg_key.encode('utf8')) + self.update_id = self.message_id + return self.message_id + + @staticmethod + def parse_decrypted_message_id(decrypted_token): + """Parses a decrypted message-id into component parts + + :type decrypted_token: str + :rtype: dict + + """ + topic = None + if decrypted_token.startswith("01:"): + info = decrypted_token.split(":") + if len(info) != 4: + raise InvalidTokenException("Incorrect number of token parts.") + api_ver, uaid, chid, topic = info + else: + info = decrypted_token.split(":") + if len(info) != 3: + raise InvalidTokenException("Incorrect number of token parts.") + kind, uaid, chid = decrypted_token.split(":") + if kind != "m": + raise InvalidTokenException("Incorrect token kind.") + return dict( + uaid=uaid, + chid=chid, + topic=topic, + ) + + def cleanup_headers(self): + """Sanitize the headers for this notification + + This only needs to be run when creating a notification from passed + in application server headers. + + """ + headers = self.headers + # Strip crypto/encryption headers down + for hdr in ["crypto-key", "encryption"]: + if STRIP_PADDING.search(headers.get(hdr, "")): + head = headers[hdr].replace('"', '') + headers[hdr] = STRIP_PADDING.sub("", head) + + data = dict( + encoding=headers["content-encoding"], + encryption=headers["encryption"], + ) + # AWS cannot store empty strings, so we only add these keys if + # they're present to avoid empty strings. + for name in ["encryption-key", "crypto-key"]: + if name in headers: + # NOTE: The client code expects all header keys to be lower + # case and s/-/_/. + data[name.lower().replace("-", "_")] = headers[name] + self.headers = data + + @property + def sort_key(self): + """Return an appropriate sort_key for this notification""" + chid = normalize_id(self.channel_id.hex) + if self.topic: + return "01:{chid}:{topic}".format(chid=chid, topic=self.topic) + else: + return "{chid}:{message_id}".format(chid=chid, + message_id=self.message_id) + + @staticmethod + def parse_sort_key(sort_key): + """Parse the sort key from the database + + :type sort_key: str + :rtype: dict + + """ + topic = None + message_id = None + if re.match(r'^\d\d:', sort_key): + api_ver, channel_id, topic = sort_key.split(":") + else: + channel_id, message_id = sort_key.split(":") + api_ver = "00" + return dict(api_ver=api_ver, channel_id=channel_id, + topic=topic, message_id=message_id) + + @property + def location(self): + """Return an appropriate value for the Location header""" + return self.message_id + + def expired(self, at_time=None): + """Indicates whether the message has expired or not + + :param at_time: Optional time to compare for expiration + :type at_time: int + + """ + now = at_time or int(time.time()) + return now >= (self.ttl + self.timestamp) + + @classmethod + def from_message_table(cls, uaid, item): + """Create a WebPushNotification from a message table item + + :type uaid: uuid.UUID + :type item: dict or boto.dynamodb2.item.Item + + :rtype: WebPushNotification + + """ + key_info = cls.parse_sort_key(item["chidmessageid"]) + if key_info.get("topic"): + key_info["message_id"] = item["updateid"] + + return cls(uaid=uaid, channel_id=uuid.UUID(key_info["channel_id"]), + data=item.get("data"), + headers=item.get("headers"), + ttl=item["ttl"], + topic=key_info.get("topic"), + message_id=key_info["message_id"], + update_id=item.get("updateid"), + timestamp=item.get("timestamp"), + ) + + @classmethod + def from_webpush_request_schema(cls, data, fernet): + """Create a WebPushNotification from a validated WebPushRequestSchema + + This is a blocking call. + + :type data: autopush.web.validation.WebPushRequestSchema + :type fernet: cryptography.fernet.Fernet + + :rtype: WebPushNotification + + """ + sub = data["subscription"] + notif = cls(uaid=sub["uaid"], channel_id=sub["chid"], + data=data["body"], headers=data["headers"], + ttl=data["headers"]["ttl"], + topic=data["headers"]["topic"]) + + if notif.data: + notif.cleanup_headers() + else: + notif.headers = None + + notif.generate_message_id(fernet) + return notif + + @classmethod + def from_message_id(cls, message_id, fernet): + """Create a WebPushNotification from a message_id + + This is a blocking call. + + The resulting WebPushNotification is not a complete one + from the database, but has all the parsed attributes + available that can be derived from the message_id. + + This is suitable for passing to delete calls. + + :type message_id: str + :type fernet: cryptography.fernet.Fernet + + :rtype: WebPushNotification + + """ + decrypted_message_id = fernet.decrypt(message_id) + key_info = cls.parse_decrypted_message_id(decrypted_message_id) + notif = cls(uaid=uuid.UUID(key_info["uaid"]), + channel_id=uuid.UUID(key_info["chid"]), + data=None, + ttl=None, + topic=key_info["topic"], + message_id=message_id, + ) + if key_info["topic"]: + notif.update_id = message_id + return notif + + @classmethod + def from_serialized(cls, uaid, data): + """Create a WebPushNotification from a deserialized JSON dict + + :type uaid: uuid.UUID + :type data: dict + + :rtype: WebPushNotification + + """ + notif = cls(uaid=uaid, channel_id=uuid.UUID(data["channelID"]), + data=data.get("data"), + headers=data.get("headers"), + ttl=data.get("ttl"), + topic=data.get("topic"), + message_id=str(data["version"]), + update_id=str(data["version"]), + timestamp=data.get("timestamp"), + ) + return notif + + @property + def version(self): + """Return a 'version' for use with a websocket client + + In our case we use the message-id as its a unique value for every + message. + + """ + return self.message_id + + def serialize(self): + """Serialize to a dict for delivery to a connection node""" + payload = dict( + channelID=normalize_id(self.channel_id.hex), + version=self.version, + ttl=self.ttl, + topic=self.topic, + timestamp=self.timestamp, + ) + if self.data: + payload["data"] = self.data + payload["headers"] = self.headers + return payload + + def websocket_format(self): + """Format a notification for a websocket client""" + # Firefox currently requires channelIDs to be '-' formatted. + payload = dict( + messageType="notification", + channelID=normalize_id(self.channel_id.hex), + version=self.version, + ) + if self.data: + payload["data"] = self.data + payload["headers"] = self.headers + return payload diff --git a/autopush/web/base.py b/autopush/web/base.py index a19118c4..1256c0d6 100644 --- a/autopush/web/base.py +++ b/autopush/web/base.py @@ -1,7 +1,7 @@ import json import time -from collections import namedtuple +from attr import attrs, attrib from boto.dynamodb2.exceptions import ( ProvisionedThroughputExceededException, ) @@ -27,9 +27,12 @@ "#error-codes") -class Notification(namedtuple("Notification", - "version data channel_id headers ttl")): +@attrs +class Notification(object): """Parsed notification from the request""" + version = attrib() + data = attrib() + channel_id = attrib() class BaseWebHandler(BaseHandler): diff --git a/autopush/web/simplepush.py b/autopush/web/simplepush.py index 6068ac78..3469b096 100644 --- a/autopush/web/simplepush.py +++ b/autopush/web/simplepush.py @@ -29,8 +29,7 @@ def put(self, api_ver="v1", token=None): version=self.valid_input["version"], data=self.valid_input["data"], channel_id=str(sub["chid"]), - headers=self.request.headers, - ttl=None) + ) d = Deferred() d.addCallback(router.route_notification, user_data) diff --git a/autopush/web/validation.py b/autopush/web/validation.py index 2756fb87..5995d65f 100644 --- a/autopush/web/validation.py +++ b/autopush/web/validation.py @@ -28,13 +28,16 @@ from autopush.utils import ( base64url_encode, extract_jwt, -) + WebPushNotification) MAX_TTL = 60 * 60 * 24 * 60 # Older versions used "bearer", newer specification requires "webpush" AUTH_SCHEMES = ["bearer", "webpush"] PREF_SCHEME = "webpush" +# Base64 URL validation +VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_]+=*$') + class ThreadedValidate(object): """A cyclone request validation decorator @@ -110,11 +113,6 @@ def post(self): threaded_validate = ThreadedValidate.validate -# Remove trailing padding characters from complex header items like -# Crypto-Key and Encryption -strip_padding = re.compile('=+(?=[,;]|$)') - - class SimplePushSubscriptionSchema(Schema): uaid = fields.UUID(required=True) chid = fields.UUID(required=True) @@ -234,8 +232,22 @@ class WebPushHeaderSchema(Schema): encryption = fields.String() encryption_key = fields.String(load_from="encryption-key") ttl = fields.Integer(required=False, missing=None) + topic = fields.String(required=False, missing=None) api_ver = fields.String() + @validates('topic') + def validate_topic(self, value): + if value is None: + return True + + if len(value) > 32: + raise InvalidRequest("Topic must be no greater than 32 " + "characters", errno=113) + + if not VALID_BASE64_URL.match(value): + raise InvalidRequest("Topic must be URL and Filename safe Base" + "64 alphabet", errno=113) + @validates_schema def validate_cypto_headers(self, d): # Not allowed to use aesgcm128 + a crypto_key @@ -352,18 +364,11 @@ def fixup_output(self, d): # schema logic to run first. self.validate_auth(d) - # Add a message_id - sub = d["subscription"] - d["message_id"] = self.context["settings"].fernet.encrypt( - ":".join(["m", sub["uaid"].hex, sub["chid"].hex]).encode('utf8') - ) - - # Strip crypto/encryption headers down - for hdr in ["crypto-key", "encryption"]: - if strip_padding.search(d["headers"].get(hdr, "")): - head = d["headers"][hdr].replace('"', '') - d["headers"][hdr] = strip_padding.sub("", head) - # Base64-encode data for Web Push d["body"] = base64url_encode(d["body"]) + + # Set the notification based on the validated request schema data + d["notification"] = WebPushNotification.from_webpush_request_schema( + data=d, fernet=self.context["settings"].fernet + ) return d diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 8b48a8a4..6480c1d1 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -5,7 +5,6 @@ from autopush.web.base import ( BaseWebHandler, - Notification, ) from autopush.web.validation import ( threaded_validate, @@ -35,18 +34,11 @@ def post(self, api_ver="v1", token=None): for i in jwt["jwt_data"]: self._client_info["jwt_" + i] = jwt["jwt_data"][i] - sub = self.valid_input["subscription"] - user_data = sub["user_data"] + user_data = self.valid_input["subscription"]["user_data"] router = self.ap_settings.routers[user_data["router_type"]] - self._client_info["message_id"] = self.valid_input["message_id"] + notification = self.valid_input["notification"] + self._client_info["message_id"] = notification.message_id - notification = Notification( - version=self.valid_input["message_id"], - data=self.valid_input["body"], - channel_id=str(sub["chid"]), - headers=self.valid_input["headers"], - ttl=self.valid_input["headers"]["ttl"] - ) self._client_info["uaid"] = hasher(user_data.get("uaid")) self._client_info["channel_id"] = user_data.get("chid") d = Deferred() diff --git a/autopush/websocket.py b/autopush/websocket.py index 42688411..d803001d 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -41,6 +41,7 @@ ProvisionedThroughputExceededException, ItemNotFound ) +from typing import List # flake8: noqa from twisted.internet import reactor from twisted.internet.defer import ( Deferred, @@ -71,7 +72,7 @@ from autopush.utils import ( parse_user_agent, validate_uaid, -) + WebPushNotification) from autopush.noseplugin import track_object @@ -139,6 +140,7 @@ class PushState(object): '_paused', 'metrics', 'uaid', + 'uaid_obj', 'uaid_hash', 'raw_agent', 'last_ping', @@ -194,6 +196,7 @@ def __init__(self, settings, request): self.metrics.increment("client.socket.connect", tags=self._base_tags or None) self.uaid = None + self.uaid_obj = None self.uaid_hash = "" self.last_ping = 0 self.check_storage = False @@ -518,13 +521,7 @@ def _save_webpush_notif(self, notif): """Save a direct_update webpush style notification""" return deferToThread( self.ps.message.store_message, - uaid=self.ps.uaid, - channel_id=notif.channel_id, - data=notif.data, - headers=notif.headers, - message_id=notif.version, - ttl=notif.ttl, - timestamp=notif.timestamp, + notif ).addErrback(self.log_failure) def _save_simple_notif(self, channel_id, version): @@ -639,6 +636,7 @@ def process_hello(self, data): existing_user, uaid = validate_uaid(uaid) self.ps.uaid = uaid + self.ps.uaid_obj = uuid.UUID(uaid) self.ps.uaid_hash = hasher(uaid) # Check for the special wakeup commands if "wakeup_host" in data and "mobilenetwork" in data: @@ -841,7 +839,7 @@ def process_notifications(self): if self.ps.use_webpush: d = self.deferToThread(self.ps.message.fetch_messages, - self.ps.uaid) + self.ps.uaid_obj) else: d = self.deferToThread( self.ap_settings.storage.fetch_notifications, self.ps.uaid) @@ -897,7 +895,11 @@ def finish_notifications(self, notifs): d.addErrback(self.trap_cancel) def finish_webpush_notifications(self, notifs): - """webpush notification processor""" + """webpush notification processor + + :type notifs: List[autopush.utils.WebPushNotification] + + """ if not notifs: # No more notifications, we can stop. self.ps._more_notifications = False @@ -918,30 +920,13 @@ def finish_webpush_notifications(self, notifs): # Send out all the notifications now = int(time.time()) for notif in notifs: - # Split off the chid and message id - chid, version = notif["chidmessageid"].split(":") - # If the TTL is too old, don't deliver and fire a delete off - if not notif["ttl"] or now >= (notif["ttl"]+notif["timestamp"]): - self.force_retry( - self.ps.message.delete_message, self.ps.uaid, - chid, version, updateid=notif["updateid"]) + if notif.expired(at_time=now): + self.force_retry(self.ps.message.delete_message, notif) continue - data = notif.get("data") - msg = dict( - messageType="notification", - channelID=chid, - version=version + ":" + notif["updateid"], - ) - if data: - msg["data"] = data - msg["headers"] = notif["headers"] - self.ps.updates_sent[chid].append( - Notification(channel_id=chid, version=version, - data=notif["data"], headers=notif.get("headers"), - ttl=notif["ttl"], timestamp=notif["timestamp"]) - ) + self.ps.updates_sent[str(notif.channel_id)].append(notif) + msg = notif.websocket_format() self.sendJSON(msg) def _rotate_message_table(self): @@ -1098,12 +1083,9 @@ def process_unregister(self, data): self.ps.updates_sent.pop(chid, None) if self.ps.use_webpush: - # Unregister the channel, delete all messages stored + # Unregister the channel self.force_retry(self.ap_settings.message.unregister_channel, self.ps.uaid, chid) - self.force_retry( - self.ap_settings.message.delete_messages_for_channel, - self.ps.uaid, chid) else: # Delete any record from storage, we don't wait for this self.force_retry(self.ap_settings.storage.delete_notification, @@ -1136,11 +1118,8 @@ def ack_update(self, update): def _handle_webpush_ack(self, chid, version, code): """Handle clearing out a webpush ack""" - # Split off the updateid if its not a direct update - version, updateid = version.split(":") - - def ver_filter(update): - return update.version == version + def ver_filter(notif): + return notif.version == version found = filter(ver_filter, self.ps.direct_updates[chid]) if found: @@ -1163,11 +1142,7 @@ def ver_filter(update): message_size=size, uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, code=code, **self.ps.raw_agent) - d = self.force_retry(self.ps.message.delete_message, - uaid=self.ps.uaid, - channel_id=chid, - message_id=version, - updateid=updateid) + d = self.force_retry(self.ps.message.delete_message, msg) # We don't remove the update until we know the delete ran # This is because we don't use range queries on dynamodb and we # need to make sure this notification is deleted from the db before @@ -1288,21 +1263,11 @@ def send_notifications(self, update): return if self.ps.use_webpush: - response = dict( - messageType="notification", - channelID=chid, - version="%s:" % version, - ) - data = update.get("data") - if data: - response["data"] = data - response["headers"] = update["headers"] - self.ps.direct_updates[chid].append( - Notification(channel_id=chid, version=version, - data=data, headers=update.get("headers"), - ttl=update["ttl"], timestamp=update["timestamp"]) - ) - self.sendJSON(response) + # Create the notification + notif = WebPushNotification.from_serialized(self.ps.uaid_obj, + update) + self.ps.direct_updates[chid].append(notif) + self.sendJSON(notif.websocket_format()) else: self.ps.direct_updates[chid] = version msg = {"messageType": "notification", "updates": [update]} diff --git a/base-requirements.txt b/base-requirements.txt index de42435b..aa04a1c4 100644 --- a/base-requirements.txt +++ b/base-requirements.txt @@ -32,7 +32,7 @@ idna==2.1 ipaddress==1.0.16 itsdangerous==0.24 jmespath==0.9.0 -marshmallow==2.9.1 +marshmallow==2.10.2 mccabe==0.5.2 pbr==1.10.0 pluggy==0.3.1 @@ -51,6 +51,7 @@ service-identity==16.0.0 simplejson==3.8.2 six==1.10.0 translationstring==1.3 +typing==3.5.2.2 -e git+https://github.com/habnabit/txstatsd.git@157ef85fbdeafe23865c7c4e176237ffcb3c3f1f#egg=txStatsD-master txaio==2.5.1 ua_parser==0.7.1 diff --git a/docs/http.rst b/docs/http.rst index 64c84308..0fa612ba 100644 --- a/docs/http.rst +++ b/docs/http.rst @@ -93,6 +93,7 @@ Unless otherwise specified, all calls return the following error codes: - Missing Crypto Headers - Include the appropriate encryption headers (`WebPush Encryption §3.2 `_ and `WebPush VAPID §4 `_) - errno 112 - Invalid TTL header value + - errno 113 - Invalid Topic header value - 401 - Bad Authorization