From f46f8d7f251ac4981ec21ce4a784fcc0c1927cc6 Mon Sep 17 00:00:00 2001 From: Ben Bangert Date: Sun, 23 Oct 2016 23:02:12 -0700 Subject: [PATCH] feat: skip timestamped messages instead of deleting Alleviates heavy writes to remove messages by instead timestamping messages and skipping through them. The first message record retains the timestamp that was read up to for later use. This also removes an existing edge case where the connection node would not fetch more messages if all the ones in a batch it fetched had their TTL expired. Closes #661 --- autopush/db.py | 192 +++++++++++++++++++++-------- autopush/settings.py | 4 + autopush/tests/test_db.py | 3 +- autopush/tests/test_endpoint.py | 11 ++ autopush/tests/test_integration.py | 136 ++++++++++++++++++-- autopush/tests/test_websocket.py | 23 +++- autopush/utils.py | 141 +++++++++++++++------ autopush/web/push_validation.py | 3 +- autopush/websocket.py | 141 +++++++++++++++++---- 9 files changed, 517 insertions(+), 137 deletions(-) diff --git a/autopush/db.py b/autopush/db.py index 60614e5e..e451e125 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -47,28 +47,40 @@ from boto.dynamodb2.layer1 import DynamoDBConnection from boto.dynamodb2.table import Table, Item from boto.dynamodb2.types import NUMBER -from typing import Iterable, List # flake8: noqa +from typing import ( # noqa + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + TypeVar, + Tuple, + Union, +) from autopush.exceptions import AutopushException +from autopush.metrics import IMetrics # noqa from autopush.utils import ( generate_hash, normalize_id, WebPushNotification, ) +# Typing +T = TypeVar('T') # noqa + key_hash = "" TRACK_DB_CALLS = False DB_CALLS = [] def get_month(delta=0): + # type: (int) -> datetime.date """Basic helper function to get a datetime.date object iterations months ahead/behind of now. - :type delta: int - - :rtype: datetime.datetime - """ new = last = datetime.date.today() # Move until we hit a new month, this avoids having to manually @@ -85,12 +97,15 @@ def get_month(delta=0): def hasher(uaid): + # type: (str) -> str + """Hashes a key using a key_hash if present""" if key_hash: return generate_hash(key_hash, uaid) return uaid def dump_uaid(uaid_data): + # type: (Union[Dict[str, Any], Item]) -> str """Return a dict for a uaid. This is utilized instead of repr since some db methods return a @@ -105,6 +120,7 @@ def dump_uaid(uaid_data): def make_rotating_tablename(prefix, delta=0, date=None): + # type: (str, int, Optional[datetime.date]) -> str """Creates a tablename for table rotation based on a prefix with a given month delta.""" if not date: @@ -114,6 +130,7 @@ def make_rotating_tablename(prefix, delta=0, date=None): def create_rotating_message_table(prefix="message", read_throughput=5, write_throughput=5, delta=0): + # type: (str, int, int, int) -> Table """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta) return Table.create(tablename, @@ -127,6 +144,7 @@ def create_rotating_message_table(prefix="message", read_throughput=5, def get_rotating_message_table(prefix="message", delta=0, date=None, message_read_throughput=5, message_write_throughput=5): + # type: (str, int, Optional[datetime.date], int, int) -> Table """Gets the message table for the current month.""" db = DynamoDBConnection() dblist = db.list_tables()["TableNames"] @@ -142,6 +160,7 @@ def get_rotating_message_table(prefix="message", delta=0, date=None, def create_router_table(tablename="router", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Create a new router table The last_connect index is a value used to determine the last month a user @@ -164,13 +183,14 @@ def create_router_table(tablename="router", read_throughput=5, 'AccessIndex', parts=[ HashKey('last_connect', - data_type=NUMBER)], + data_type=NUMBER)], throughput=dict(read=5, write=5))], ) def create_storage_table(tablename="storage", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Create a new storage table for simplepush style notification storage""" return Table.create(tablename, schema=[HashKey("uaid"), RangeKey("chid")], @@ -180,6 +200,7 @@ def create_storage_table(tablename="storage", read_throughput=5, def _make_table(table_func, tablename, read_throughput, write_throughput): + # type: (Callable[[str, int, int], Table], str, int, int) -> Table """Private common function to make a table with a table func""" db = DynamoDBConnection() dblist = db.list_tables()["TableNames"] @@ -191,6 +212,7 @@ def _make_table(table_func, tablename, read_throughput, write_throughput): def get_router_table(tablename="router", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Get the main router table object Creates the table if it doesn't already exist, otherwise returns the @@ -203,6 +225,7 @@ def get_router_table(tablename="router", read_throughput=5, def get_storage_table(tablename="storage", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Get the main storage table object Creates the table if it doesn't already exist, otherwise returns the @@ -214,6 +237,7 @@ def get_storage_table(tablename="storage", read_throughput=5, def preflight_check(storage, router, uaid="deadbeef00000000deadbeef00000000"): + # type: (Storage, Router, str) -> None """Performs a pre-flight check of the storage/router/message to ensure appropriate permissions for operation. @@ -255,6 +279,7 @@ def preflight_check(storage, router, uaid="deadbeef00000000deadbeef00000000"): def track_provisioned(func): + # type: (Callable[..., T]) -> Callable[..., T] """Tracks provisioned exceptions and increments a metric for them named after the function decorated""" @wraps(func) @@ -276,13 +301,8 @@ def wrapper(self, *args, **kwargs): def has_connected_this_month(item): - """Whether or not a router item has connected this month - - :type item: dict - - :rtype: bool - - """ + # type: (Dict[str, Any]) -> bool + """Whether or not a router item has connected this month""" last_connect = item.get("last_connect") if not last_connect: return False @@ -293,16 +313,13 @@ def has_connected_this_month(item): def generate_last_connect(): + # type: () -> int """Generate a last_connect This intentionally generates a limited set of keys for each month in a known sequence. For each month, there's 24 hours * 10 random numbers for a total of 240 keys per month depending on when the user migrates forward. - :type date: datetime.datetime - - :rtype: int - """ today = datetime.datetime.today() val = "".join([ @@ -315,15 +332,12 @@ def generate_last_connect(): def generate_last_connect_values(date): + # type: (datetime.date) -> Iterable[int] """Generator of last_connect values for a given date Creates an iterator that yields all the valid values for ``last_connect`` for a given year/month. - :type date: datetime.datetime - - :rtype: Iterable[int] - """ year = str(date.year) month = str(date.month).zfill(2) @@ -337,6 +351,7 @@ def generate_last_connect_values(date): class Storage(object): """Create a Storage table abstraction on top of a DynamoDB Table object""" def __init__(self, table, metrics): + # type: (Table, IMetrics) -> None """Create a new Storage object :param table: :class:`Table` object. @@ -350,6 +365,7 @@ def __init__(self, table, metrics): @track_provisioned def fetch_notifications(self, uaid): + # type: (str) -> List[Dict[str, Any]] """Fetch all notifications for a UAID :raises: @@ -363,6 +379,7 @@ def fetch_notifications(self, uaid): @track_provisioned def save_notification(self, uaid, chid, version): + # type: (str, str, Optional[int]) -> bool """Save a notification for the UAID :raises: @@ -388,10 +405,10 @@ def save_notification(self, uaid, chid, version): return False def delete_notification(self, uaid, chid, version=None): + # type: (str, str, Optional[int]) -> bool """Delete a notification for a UAID :returns: Whether or not the notification was able to be deleted. - :rtype: bool """ try: @@ -411,6 +428,7 @@ def delete_notification(self, uaid, chid, version=None): class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" def __init__(self, table, metrics): + # type: (Table, IMetrics) -> None """Create a new Message object :param table: :class:`Table` object. @@ -424,6 +442,7 @@ def __init__(self, table, metrics): @track_provisioned def register_channel(self, uaid, channel_id): + # type: (str, str) -> bool """Register a channel for a given uaid""" conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) @@ -441,6 +460,7 @@ def register_channel(self, uaid, channel_id): @track_provisioned def unregister_channel(self, uaid, channel_id, **kwargs): + # type: (str, str, **str) -> bool """Remove a channel registration for a given uaid""" conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) @@ -466,6 +486,7 @@ def unregister_channel(self, uaid, channel_id, **kwargs): @track_provisioned def all_channels(self, uaid): + # type: (str) -> Tuple[bool, Set[str]] """Retrieve a list of all channels for a given uaid""" # Note: This only returns the chids associated with the UAID. @@ -480,6 +501,7 @@ def all_channels(self, uaid): @track_provisioned def save_channels(self, uaid, channels): + # type: (str, Set[str]) -> None """Save out a set of channels""" self.table.put_item(data=dict( uaid=hasher(uaid), @@ -489,12 +511,8 @@ def save_channels(self, uaid, channels): @track_provisioned def store_message(self, notification): - """Stores a WebPushNotification in the message table - - :type notification: WebPushNotification - :type timestamp: int - - """ + # type: (WebPushNotification) -> bool + """Stores a WebPushNotification in the message table""" item = dict( uaid=hasher(notification.uaid.hex), chidmessageid=notification.sort_key, @@ -509,11 +527,8 @@ def store_message(self, notification): @track_provisioned def delete_message(self, notification): - """Deletes a specific message - - :type notification: WebPushNotification - - """ + # type: (WebPushNotification) -> bool + """Deletes a specific message""" if notification.update_id: try: self.table.delete_item( @@ -530,25 +545,94 @@ def delete_message(self, notification): return True @track_provisioned - def fetch_messages(self, uaid, limit=10): + def fetch_messages( + self, + uaid, # type: uuid.UUID + limit=10, # type: int + ): + # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] """Fetches messages for a uaid - :type uaid: uuid.UUID - :type limit: int + :returns: A tuple of the last timestamp to read for timestamped + messages and the list of non-timestamped messages. """ # Eagerly fetches all results in the result set. - results = self.table.query_2(uaid__eq=hasher(uaid.hex), - chidmessageid__gt=" ", - consistent=True, limit=limit) - return [ + results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), + chidmessageid__lt="02", + consistent=True, limit=limit)) + + # First extract the position if applicable, slightly higher than 01: + # to ensure we don't load any 01 remainders that didn't get deleted + # yet + last_position = None + if results: + # Ensure we return an int, as boto2 can return Decimals + if results[0].get("current_timestamp"): + last_position = int(results[0]["current_timestamp"]) + + return last_position, [ + WebPushNotification.from_message_table(uaid, x) + for x in results[1:] + ] + + @track_provisioned + def fetch_timestamp_messages( + self, + uaid, # type: uuid.UUID + timestamp=None, # type: Optional[int] + limit=10, # type: int + ): + # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] + """Fetches timestamped messages for a uaid + + Note that legacy messages start with a hex UUID, so they may be mixed + in with timestamp messages beginning with 02. As such we only move our + last_position forward to the last timestamped message. + + :returns: A tuple of the last timestamp to read and the list of + timestamped messages. + + """ + # Turn the timestamp into a proper sort key + if timestamp: + sortkey = "02:{timestamp}:z".format(timestamp=timestamp) + else: + sortkey = "01;" + + results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), + chidmessageid__gt=sortkey, + consistent=True, limit=limit)) + notifs = [ WebPushNotification.from_message_table(uaid, x) for x in results ] + ts_notifs = [x for x in notifs if x.sortkey_timestamp] + last_position = None + if ts_notifs: + last_position = ts_notifs[-1].sortkey_timestamp + return last_position, notifs + + @track_provisioned + def update_last_message_read(self, uaid, timestamp): + # type: (uuid.UUID, int) -> bool + """Update the last read timestamp for a user""" + conn = self.table.connection + db_key = self.encode({"uaid": hasher(uaid.hex), "chidmessageid": " "}) + expr = "SET current_timestamp=:timestamp" + expr_values = self.encode({":timestamp": timestamp}) + conn.update_item( + self.table.table_name, + db_key, + update_expression=expr, + expression_attribute_values=expr_values, + ) + return True class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" def __init__(self, table, metrics): + # type: (Table, IMetrics) -> None """Create a new Router object :param table: :class:`Table` object. @@ -561,10 +645,9 @@ def __init__(self, table, metrics): self.encode = table._encode_keys def get_uaid(self, uaid): + # type: (str) -> Item """Get the database record for the UAID - :returns: User item - :rtype: :class:`~boto.dynamodb2.items.Item` :raises: :exc:`ItemNotFound` if there is no record for this UAID. :exc:`ProvisionedThroughputExceededException` if dynamodb table @@ -592,13 +675,13 @@ def get_uaid(self, uaid): @track_provisioned def register_user(self, data): + # type: (Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Dict[str, Any]] """Register this user If a record exists with a newer ``connected_at``, then the user will not be registered. :returns: Whether the user was registered or not. - :rtype: tuple :raises: :exc:`ProvisionedThroughputExceededException` if dynamodb table exceeds throughput. @@ -647,6 +730,8 @@ def register_user(self, data): @track_provisioned def drop_user(self, uaid): + # type: (str) -> bool + """Drops a user record""" # The following hack ensures that only uaids that exist and are # deleted return true. huaid = hasher(uaid) @@ -654,16 +739,14 @@ def drop_user(self, uaid): expected={"uaid__eq": huaid}) def delete_uaids(self, uaids): - """Issue a batch delete call for the given uaids - - :type uaids: List[str] - - """ + # type: (List[str]) -> None + """Issue a batch delete call for the given uaids""" with self.table.batch_write() as batch: for uaid in uaids: batch.delete_item(uaid=uaid) def drop_old_users(self, months_ago=2): + # type: (int) -> Iterable[int] """Drops user records that have no recent connection Utilizes the last_connect index to locate users that haven't @@ -684,10 +767,8 @@ def drop_old_users(self, months_ago=2): quickly as possible. :param months_ago: how many months ago since the last connect - :type months_ago: int :returns: Iterable of how many deletes were run - :rtype: Iterable[int] """ prior_date = get_month(-months_ago) @@ -713,17 +794,20 @@ def drop_old_users(self, months_ago=2): @track_provisioned def update_message_month(self, uaid, month): + # type: (str, str) -> bool """Update the route tables current_message_month Note that we also update the last_connect at this point since webpush - users when connecting will always call this once that month. + users when connecting will always call this once that month. The + current_timestamp is also reset as a new month has no last read + timestamp. """ conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid)}) - expr = "SET current_month=:curmonth, last_connect=:last_connect" + expr = ("SET current_month=:curmonth, last_connect=:last_connect") expr_values = self.encode({":curmonth": month, - ":last_connect": generate_last_connect() + ":last_connect": generate_last_connect(), }) conn.update_item( self.table.table_name, @@ -735,13 +819,13 @@ def update_message_month(self, uaid, month): @track_provisioned def clear_node(self, item): + # type: (Item) -> bool """Given a router item and remove the node_id The node_id will only be cleared if the ``connected_at`` matches up with the item's ``connected_at``. :returns: Whether the node was cleared or not. - :rtype: bool :raises: :exc:`ProvisionedThroughputExceededException` if dynamodb table exceeds throughput. diff --git a/autopush/settings.py b/autopush/settings.py index 7c5126c9..2936ed3a 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -211,6 +211,10 @@ def __init__(self, self.ami_id = ami_id + # Generate messages per legacy rules, only used for testing to + # generate legacy data. + self._notification_legacy = False + @property def message(self): """Property that access the current message table""" diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index f27ca00f..e930be53 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -315,7 +315,8 @@ def test_message_storage(self): message.store_message(make_webpush_notification(self.uaid, chid)) message.store_message(make_webpush_notification(self.uaid, chid)) - all_messages = list(message.fetch_messages(uuid.UUID(self.uaid))) + _, all_messages = message.fetch_timestamp_messages( + uuid.UUID(self.uaid), " ") eq_(len(all_messages), 3) def test_message_storage_overwrite(self): diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index d1a7d751..cd2d0fa0 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -139,6 +139,17 @@ def handle_finish(result): self.message.delete(self._make_req('ignored')) return self.finish_deferred + def test_delete_invalid_timestamp_token(self): + tok = ":".join(["02", str(dummy_chid)]) + self.fernet_mock.decrypt.return_value = tok + + def handle_finish(result): + self.status_mock.assert_called_with(400, reason=None) + self.finish_deferred.addCallback(handle_finish) + + self.message.delete(self._make_req('ignored')) + return self.finish_deferred + def test_delete_success(self): tok = ":".join(["m", dummy_uaid.hex, str(dummy_chid)]) self.fernet_mock.decrypt.return_value = tok diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index 61940d12..5d013458 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -8,6 +8,7 @@ import time import urlparse import uuid +from contextlib import contextmanager from StringIO import StringIO from unittest.case import SkipTest @@ -448,6 +449,12 @@ def shut_down(self, client=None): if client: yield client.disconnect() + @contextmanager + def legacy_endpoint(self): + self._settings._notification_legacy = True + yield + self._settings._notification_legacy = False + class TestSimple(IntegrationBase): @inlineCallbacks @@ -703,15 +710,16 @@ def test_uaid_resumption_on_reconnect(self): class TestWebPush(IntegrationBase): @inlineCallbacks - def test_hello_only_has_two_calls(self): + def test_hello_only_has_three_calls(self): db.TRACK_DB_CALLS = True client = Client(self._ws_url, use_webpush=True) yield client.connect() result = yield client.hello() ok_(result != {}) eq_(result["use_webpush"], True) - yield client.wait_for(lambda: len(db.DB_CALLS) == 2) - eq_(db.DB_CALLS, ['register_user', 'fetch_messages']) + yield client.wait_for(lambda: len(db.DB_CALLS) == 3) + eq_(db.DB_CALLS, ['register_user', 'fetch_messages', + 'fetch_timestamp_messages']) db.DB_CALLS = [] db.TRACK_DB_CALLS = False @@ -872,17 +880,18 @@ def test_multiple_delivery_repeat_without_ack(self): yield self.shut_down(client) @inlineCallbacks - def test_multiple_delivery_with_single_ack(self): + def test_multiple_legacy_delivery_with_single_ack(self): data = str(uuid.uuid4()) data2 = str(uuid.uuid4()) client = yield self.quick_register(use_webpush=True) yield client.disconnect() ok_(client.channels) - yield client.send_notification(data=data) - yield client.send_notification(data=data2) + with self.legacy_endpoint(): + yield client.send_notification(data=data) + yield client.send_notification(data=data2) yield client.connect() yield client.hello() - result = yield client.get_notification() + result = yield client.get_notification(timeout=5) ok_(result != {}) ok_(result["data"] in map(base64url_encode, [data, data2])) result = yield client.get_notification() @@ -901,6 +910,45 @@ def test_multiple_delivery_with_single_ack(self): eq_(result, None) yield self.shut_down(client) + @inlineCallbacks + def test_multiple_delivery_with_single_ack(self): + data = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + ok_(client.channels) + yield client.send_notification(data=data) + yield client.send_notification(data=data2) + yield client.connect() + yield client.hello() + result = yield client.get_notification() + ok_(result != {}) + eq_(result["data"], base64url_encode(data)) + result2 = yield client.get_notification() + ok_(result2 != {}) + eq_(result2["data"], base64url_encode(data2)) + yield client.ack(result["channelID"], result["version"]) + + yield client.disconnect() + yield client.connect() + yield client.hello() + result = yield client.get_notification() + ok_(result != {}) + eq_(result["data"], base64url_encode(data)) + ok_(result["messageType"], "notification") + result2 = yield client.get_notification() + ok_(result2 != {}) + eq_(result2["data"], base64url_encode(data2)) + yield client.ack(result2["channelID"], result2["version"]) + + # Verify no messages are delivered + yield client.disconnect() + yield client.connect() + yield client.hello() + result = yield client.get_notification() + ok_(result is None) + yield self.shut_down(client) + @inlineCallbacks def test_multiple_delivery_with_multiple_ack(self): data = str(uuid.uuid4()) @@ -1021,6 +1069,64 @@ def test_ttl_expired(self): eq_(result, None) yield self.shut_down(client) + @inlineCallbacks + def test_ttl_batch_expired_and_good_one(self): + data = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + for x in range(0, 12): + yield client.send_notification(data=data, ttl=1) + + yield client.send_notification(data=data2) + time.sleep(1.5) + yield client.connect() + yield client.hello() + result = yield client.get_notification(timeout=4) + ok_(result is not None) + eq_(result["headers"]["encryption"], client._crypto_key) + eq_(result["data"], base64url_encode(data2)) + eq_(result["messageType"], "notification") + result = yield client.get_notification() + eq_(result, None) + yield self.shut_down(client) + + @inlineCallbacks + def test_ttl_batch_partly_expired_and_good_one(self): + data = str(uuid.uuid4()) + data1 = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + for x in range(0, 6): + yield client.send_notification(data=data) + + for x in range(0, 6): + yield client.send_notification(data=data1, ttl=1) + + yield client.send_notification(data=data2) + time.sleep(1.5) + yield client.connect() + yield client.hello() + + # Pull out and ack the first + for x in range(0, 6): + result = yield client.get_notification(timeout=4) + ok_(result is not None) + eq_(result["data"], base64url_encode(data)) + yield client.ack(result["channelID"], result["version"]) + + # Should have one more that is data2, this will only arrive if the + # other six were acked as that hits the batch size + result = yield client.get_notification(timeout=4) + ok_(result is not None) + eq_(result["data"], base64url_encode(data2)) + + # No more + result = yield client.get_notification() + eq_(result, None) + yield self.shut_down(client) + @inlineCallbacks def test_message_without_crypto_headers(self): data = str(uuid.uuid4()) @@ -1154,9 +1260,11 @@ def test_webpush_monthly_rotation(self): # Send in a notification, verify it landed in last months notification # table data = uuid.uuid4().hex - yield client.send_notification(data=data) - notifs = yield deferToThread(lm_message.fetch_messages, - uuid.UUID(client.uaid)) + with self.legacy_endpoint(): + yield client.send_notification(data=data) + ts, notifs = yield deferToThread(lm_message.fetch_timestamp_messages, + uuid.UUID(client.uaid), + " ") eq_(len(notifs), 1) # Connect the client, verify the migration @@ -1247,9 +1355,11 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Send in a notification, verify it landed in last months notification # table data = uuid.uuid4().hex - yield client.send_notification(data=data) - notifs = yield deferToThread(lm_message.fetch_messages, - uuid.UUID(client.uaid)) + with self.legacy_endpoint(): + yield client.send_notification(data=data) + _, notifs = yield deferToThread(lm_message.fetch_timestamp_messages, + uuid.UUID(client.uaid), + " ") eq_(len(notifs), 1) # Connect the client, verify the migration diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index dc8818b5..21a89c9c 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -583,7 +583,8 @@ def fake_msg(data): return (True, msg_data, data) mock_msg = Mock(wraps=db.Message) - mock_msg.fetch_messages.return_value = [] + mock_msg.fetch_messages.return_value = "01;", [] + mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) self.proto.ap_settings.router.register_user = fake_msg # massage message_tables to include our fake range @@ -656,7 +657,8 @@ def fake_msg(data): return (True, msg_data, data) mock_msg = Mock(wraps=db.Message) - mock_msg.fetch_messages.return_value = [] + mock_msg.fetch_messages.return_value = "01;", [] + mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) self.proto.ap_settings.router.register_user = fake_msg # massage message_tables to include our fake range @@ -730,7 +732,8 @@ def test_hello_webpush_uses_one_db_call(self): channelIDs=[])) def check_result(msg): - eq_(db.DB_CALLS, ['register_user', 'fetch_messages']) + eq_(db.DB_CALLS, ['register_user', 'fetch_messages', + 'fetch_timestamp_messages']) eq_(msg["status"], 200) db.DB_CALLS = [] db.TRACK_DB_CALLS = False @@ -1882,6 +1885,13 @@ def test_process_notif_doesnt_run_after_stop(self): self.proto.process_notifications() eq_(self.proto.ps._notification_fetch, None) + def test_check_notif_doesnt_run_after_stop(self): + self._connect() + self.proto.ps.uaid = uuid.uuid4().hex + self.proto.ps._should_stop = True + self.proto.check_missed_notifications(None) + eq_(self.proto.ps._notification_fetch, None) + def test_process_notif_paused_on_finish(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex @@ -1896,7 +1906,8 @@ def test_notif_finished_with_webpush(self): self.proto.ps.use_webpush = True self.proto.deferToLater = Mock() self.proto.ps._check_notifications = True - self.proto.finish_notifications(None) + self.proto.ps.scan_timestamps = True + self.proto.finish_notifications((None, [])) ok_(self.proto.deferToLater.called) def test_notif_finished_with_webpush_with_notifications(self): @@ -1912,7 +1923,7 @@ def test_notif_finished_with_webpush_with_notifications(self): ) self.proto.ps.updates_sent[str(notif.channel_id)] = [] - self.proto.finish_webpush_notifications([notif]) + self.proto.finish_webpush_notifications((None, [notif])) ok_(self.send_mock.called) def test_notif_finished_with_webpush_with_old_notifications(self): @@ -1930,7 +1941,7 @@ def test_notif_finished_with_webpush_with_old_notifications(self): self.proto.ps.updates_sent[str(notif.channel_id)] = [] self.proto.force_retry = Mock() - self.proto.finish_webpush_notifications([notif]) + self.proto.finish_webpush_notifications((None, [notif])) ok_(self.proto.force_retry.called) ok_(not self.send_mock.called) diff --git a/autopush/utils.py b/autopush/utils.py index 4576120a..c808b6ca 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -14,16 +14,16 @@ attrs, attrib ) -from boto.dynamodb2.items import Item # flake8: noqa -from cryptography.fernet import Fernet # flake8: noqa +from boto.dynamodb2.items import Item # noqa +from cryptography.fernet import Fernet # noqa from jose import jwt -from typing import ( +from typing import ( # noqa Any, Dict, Optional, Union, Tuple, -) # flake8: noqa +) from ua_parser import user_agent_parser from autopush.exceptions import InvalidTokenException @@ -57,6 +57,10 @@ $ """, re.VERBOSE) +# Time multipliers for conversion from seconds to ms/ns +MS_MULT = pow(10, 3) +NS_MULT = pow(10, 6) + def normalize_id(ident): # type: (Union[uuid.UUID, str]) -> str @@ -269,6 +273,7 @@ class WebPushNotification(object): data = attrib(default=None) # type: Optional[str] headers = attrib(default=None) # type: Optional[Dict[str, str]] timestamp = attrib(default=Factory(lambda: int(time.time()))) # type: int + sortkey_timestamp = attrib(default=None) # type: Optional[int] topic = attrib(default=None) # type: Optional[str] message_id = attrib(default=None) # type: str @@ -277,19 +282,26 @@ class WebPushNotification(object): # message with any update_id should be removed. update_id = attrib(default=None) # type: str + # Whether this notification should follow legacy non-topic rules + legacy = attrib(default=False) # type: bool + def generate_message_id(self, fernet): # type: (Fernet) -> str """Generate a message-id suitable for accessing the message - For non-topic messages, no sort_key version is currently used and the - message-id is: - - Encrypted(m : uaid.hex : channel_id.hex) - For topic messages, a sort_key version of 01 is used, and the topic is included for reference: - Encrypted(01 : uaid.hex : channel_id.hex : topic) + Encrypted('01' : uaid.hex : channel_id.hex : topic) + + For topic messages, a sort_key version of 02 is used: + + Encrypted('02' : uaid.hex : channel_id.hex : timestamp) + + For legacy non-topic messages, no sort_key version was used and the + message-id was: + + Encrypted('m' : uaid.hex : channel_id.hex) This is a blocking call. @@ -297,22 +309,33 @@ def generate_message_id(self, fernet): if self.topic: msg_key = ":".join(["01", self.uaid.hex, self.channel_id.hex, self.topic]) - else: + elif self.legacy: msg_key = ":".join(["m", self.uaid.hex, self.channel_id.hex]) + else: + self.sortkey_timestamp = self.sortkey_timestamp or ns_time() + msg_key = ":".join(["02", self.uaid.hex, self.channel_id.hex, + str(self.sortkey_timestamp)]) self.message_id = fernet.encrypt(msg_key.encode('utf8')) self.update_id = self.message_id return self.message_id @staticmethod def parse_decrypted_message_id(decrypted_token): - # type: (str) -> Dict[str, str] + # type: (str) -> Dict[str, Any] """Parses a decrypted message-id into component parts""" topic = None + sortkey_timestamp = None if decrypted_token.startswith("01:"): info = decrypted_token.split(":") if len(info) != 4: raise InvalidTokenException("Incorrect number of token parts.") api_ver, uaid, chid, topic = info + elif decrypted_token.startswith("02:"): + info = decrypted_token.split(":") + if len(info) != 4: + raise InvalidTokenException("Incorrect number of token parts.") + api_ver, uaid, chid, raw_sortkey = info + sortkey_timestamp = int(raw_sortkey) else: info = decrypted_token.split(":") if len(info) != 3: @@ -324,6 +347,7 @@ def parse_decrypted_message_id(decrypted_token): uaid=uaid, chid=chid, topic=topic, + sortkey_timestamp=sortkey_timestamp, ) def cleanup_headers(self): @@ -358,27 +382,54 @@ def cleanup_headers(self): @property def sort_key(self): # type: () -> str - """Return an appropriate sort_key for this notification""" + """Return an appropriate sort_key for this notification + + For new messages: + + 02:{sortkey_timestamp}:{chid} + + For topic messages: + + 01:{chid}:{topic} + + Old format for non-topic messages that is no longer returned: + + {chid}:{message_id} + + """ chid = normalize_id(self.channel_id) if self.topic: return "01:{chid}:{topic}".format(chid=chid, topic=self.topic) + elif self.legacy: + return "{chid}:{message_id}".format( + chid=chid, message_id=self.message_id + ) else: - return "{chid}:{message_id}".format(chid=chid, - message_id=self.message_id) + # Created as late as possible when storing a message + self.sortkey_timestamp = self.sortkey_timestamp or ns_time() + return "02:{sortkey_timestamp}:{chid}".format( + sortkey_timestamp=self.sortkey_timestamp, + chid=chid, + ) @staticmethod def parse_sort_key(sort_key): - # type: (str) -> Dict[str, str] + # type: (str) -> Dict[str, Any] """Parse the sort key from the database""" topic = None + sortkey_timestamp = None message_id = None - if re.match(r'^\d\d:', sort_key): + if sort_key.startswith("01:"): api_ver, channel_id, topic = sort_key.split(":") + elif sort_key.startswith("02:"): + api_ver, raw_sortkey, channel_id = sort_key.split(":") + sortkey_timestamp = int(raw_sortkey) else: channel_id, message_id = sort_key.split(":") api_ver = "00" return dict(api_ver=api_ver, channel_id=channel_id, - topic=topic, message_id=message_id) + topic=topic, message_id=message_id, + sortkey_timestamp=sortkey_timestamp) @property def location(self): @@ -401,32 +452,41 @@ def from_message_table(cls, uaid, item): # type: (uuid.UUID, Union[Dict[str, Any], Item]) -> WebPushNotification """Create a WebPushNotification from a message table item""" key_info = cls.parse_sort_key(item["chidmessageid"]) - if key_info.get("topic"): + if key_info["api_ver"] in ["01", "02"]: key_info["message_id"] = item["updateid"] - return cls(uaid=uaid, - channel_id=uuid.UUID(key_info["channel_id"]), - data=item.get("data"), - headers=item.get("headers"), - ttl=item["ttl"], - topic=key_info.get("topic"), - message_id=key_info["message_id"], - update_id=item.get("updateid"), - timestamp=item.get("timestamp"), - ) + notif = cls( + uaid=uaid, + channel_id=uuid.UUID(key_info["channel_id"]), + data=item.get("data"), + headers=item.get("headers"), + ttl=item["ttl"], + topic=key_info.get("topic"), + message_id=key_info["message_id"], + update_id=item.get("updateid"), + timestamp=item.get("timestamp"), + sortkey_timestamp=key_info.get("sortkey_timestamp") + ) + + # Ensure we generate the sort-key properly for legacy messges + if key_info["api_ver"] == "00": + notif.legacy = True + + return notif @classmethod - def from_webpush_request_schema(cls, data, fernet): - # type: (Dict[str, Any], Fernet) -> WebPushNotification + def from_webpush_request_schema(cls, data, fernet, legacy=False): + # type: (Dict[str, Any], Fernet, bool) -> WebPushNotification """Create a WebPushNotification from a validated WebPushRequestSchema This is a blocking call. """ sub = data["subscription"] - notif = cls(uaid=sub["uaid"], channel_id=sub["chid"], - data=data["body"], headers=data["headers"], - ttl=data["headers"]["ttl"], - topic=data["headers"]["topic"]) + notif = cls( + uaid=sub["uaid"], channel_id=sub["chid"], data=data["body"], + headers=data["headers"], ttl=data["headers"]["ttl"], + topic=data["headers"]["topic"], legacy=legacy, + ) if notif.data: notif.cleanup_headers() @@ -458,6 +518,7 @@ def from_message_id(cls, message_id, fernet): ttl=None, topic=key_info["topic"], message_id=message_id, + sortkey_timestamp=key_info.get("sortkey_timestamp"), ) if key_info["topic"]: notif.update_id = message_id @@ -512,7 +573,7 @@ def websocket_format(self): messageType="notification", channelID=normalize_id(self.channel_id), version=self.version, - ) + ) # type: Dict[str, Any] if self.data: payload["data"] = self.data payload["headers"] = { @@ -524,4 +585,10 @@ def websocket_format(self): def ms_time(): # type: () -> int """Return current time.time call as ms and a Python int""" - return int(time.time() * 1000) + return int(time.time() * MS_MULT) + + +def ns_time(): + # type: () -> int + """Return current time.time call as a ns int""" + return int(time.time() * NS_MULT) diff --git a/autopush/web/push_validation.py b/autopush/web/push_validation.py index 7a672b5e..01a2f8fc 100644 --- a/autopush/web/push_validation.py +++ b/autopush/web/push_validation.py @@ -289,6 +289,7 @@ def fixup_output(self, d): # Set the notification based on the validated request schema data d["notification"] = WebPushNotification.from_webpush_request_schema( - data=d, fernet=self.context["settings"].fernet + data=d, fernet=self.context["settings"].fernet, + legacy=self.context["settings"]._notification_legacy, ) return d diff --git a/autopush/websocket.py b/autopush/websocket.py index 68c95718..94c7f06d 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -32,8 +32,8 @@ import json import time import uuid -from collections import defaultdict, namedtuple -from functools import wraps +from collections import defaultdict +from functools import partial, wraps from random import randrange from autobahn.twisted.websocket import WebSocketServerProtocol @@ -59,7 +59,12 @@ from twisted.python import failure from twisted.web._newclient import ResponseFailed from twisted.web.resource import Resource -from typing import List # flake8: noqa +from typing import ( # noqa + Dict, + List, + Optional, + Tuple, +) from zope.interface import implements from autopush import __version__ @@ -68,8 +73,9 @@ has_connected_this_month, hasher, generate_last_connect, - dump_uaid + dump_uaid, ) +from autopush.db import Message # noqa from autopush.noseplugin import track_object from autopush.protocol import IgnoreBody from autopush.utils import ( @@ -149,6 +155,10 @@ class PushState(object): 'message', 'rotate_message_table', + # Timestamped message handling + 'scan_timestamps', + 'current_timestamp', + 'ping_time_out', '_check_notifications', '_more_notifications', @@ -188,8 +198,8 @@ def __init__(self, settings, request): self.metrics = settings.metrics self.metrics.increment("client.socket.connect", tags=self._base_tags or None) - self.uaid = None - self.uaid_obj = None + self.uaid = None # Optional[str] + self.uaid_obj = None # Optional[uuid.UUID] self.uaid_hash = "" self.last_ping = 0 self.check_storage = False @@ -206,6 +216,10 @@ def __init__(self, settings, request): self._check_notifications = False self._more_notifications = False + # Timestamp message defaults + self.scan_timestamps = False + self.current_timestamp = None + # Hanger for common actions we defer self._notification_fetch = None self._register = None @@ -218,6 +232,7 @@ def __init__(self, settings, request): @property def message(self): + # type: () -> Message """Property to access the currently used message table""" return self.settings.message_tables[self.message_month] @@ -322,7 +337,8 @@ def log_failure(self, failure, **kwargs): if isinstance(exc, JSONResponseError): self.log.info("JSONResponseError: {exc}", exc=exc, **kwargs) else: - self.log.failure(format="Unexpected error", failure=failure, **kwargs) + self.log.failure(format="Unexpected error", failure=failure, + **kwargs) @property def paused(self): @@ -848,8 +864,7 @@ def process_notifications(self): self.ps._more_notifications = True if self.ps.use_webpush: - d = self.deferToThread(self.ps.message.fetch_messages, - self.ps.uaid_obj) + d = self.deferToThread(self.webpush_fetch()) else: d = self.deferToThread( self.ap_settings.storage.fetch_notifications, self.ps.uaid) @@ -859,6 +874,14 @@ def process_notifications(self): d.addErrback(self.error_notifications) self.ps._notification_fetch = d + def webpush_fetch(self): + """Helper to return an appropriate function to fetch messages""" + if self.ps.scan_timestamps: + return partial(self.ps.message.fetch_timestamp_messages, + self.ps.uaid_obj, self.ps.current_timestamp) + else: + return partial(self.ps.message.fetch_messages, self.ps.uaid_obj) + def error_notifications(self, fail): """errBack for notification check failing""" # If we error'd out on this important check, we drop the connection @@ -912,16 +935,29 @@ def finish_notifications(self, notifs): d = self.deferToLater(1, self.process_notifications) d.addErrback(self.trap_cancel) - def finish_webpush_notifications(self, notifs): - """webpush notification processor + def finish_webpush_notifications(self, result): + # type: (Tuple[str, List[WebPushNotification]]) -> None + """WebPush notification processor""" + timestamp, notifs = result - :type notifs: List[autopush.utils.WebPushNotification] + # If there's a timestamp, update our current one to it + if timestamp: + self.ps.current_timestamp = timestamp - """ if not notifs: - # No more notifications, we can stop. + # No more notifications, check timestamped? + if not self.ps.scan_timestamps: + # Scan for timestamped then + self.ps.scan_timestamps = True + d = self.deferToLater(0, self.process_notifications) + d.addErrback(self.trap_cancel) + return + + # No more notifications, and we've scanned timestamped. self.ps._more_notifications = False + self.ps.scan_timestamps = False if self.ps._check_notifications: + # Told to check again, start over self.ps._check_notifications = False d = self.deferToLater(1, self.process_notifications) d.addErrback(self.trap_cancel) @@ -935,16 +971,38 @@ def finish_webpush_notifications(self, notifs): # Send out all the notifications now = int(time.time()) + messages_sent = False for notif in notifs: # If the TTL is too old, don't deliver and fire a delete off if notif.expired(at_time=now): - self.force_retry(self.ps.message.delete_message, notif) - continue + if not notif.sortkey_timestamp: + # Delete non-timestamped messages + self.force_retry(self.ps.message.delete_message, notif) + + # nocover here as coverage gets confused on the line below + # for unknown reasons + continue # pragma: nocover self.ps.updates_sent[str(notif.channel_id)].append(notif) msg = notif.websocket_format() + messages_sent = True self.sendJSON(msg) + # Did we send any messages? + if messages_sent: + return + + # No messages sent, update the record if needed + if self.ps.current_timestamp: + self.force_retry( + self.ps.message.update_last_message_read, + self.ps.uaid_obj, + self.ps.current_timestamp + ) + + # Schedule a new process check + self.check_missed_notifications(None) + def _rotate_message_table(self): """Function to fire off a message table copy of channels + update the router current_month entry""" @@ -1152,7 +1210,9 @@ def _handle_webpush_ack(self, chid, version, code): def ver_filter(notif): return notif.version == version - found = filter(ver_filter, self.ps.direct_updates[chid]) + found = filter( + ver_filter, self.ps.direct_updates[chid] + ) # type: List[WebPushNotification] if found: msg = found[0] size = len(msg.data) if msg.data else 0 @@ -1164,7 +1224,9 @@ def ver_filter(notif): self.ps.direct_updates[chid].remove(msg) return - found = filter(ver_filter, self.ps.updates_sent[chid]) + found = filter( + ver_filter, self.ps.updates_sent[chid] + ) # type: List[WebPushNotification] if found: msg = found[0] size = len(msg.data) if msg.data else 0 @@ -1173,12 +1235,36 @@ def ver_filter(notif): message_size=size, uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, code=code, **self.ps.raw_agent) - d = self.force_retry(self.ps.message.delete_message, msg) - # We don't remove the update until we know the delete ran - # This is because we don't use range queries on dynamodb and we - # need to make sure this notification is deleted from the db before - # we query it again (to avoid dupes). - d.addBoth(self._handle_webpush_update_remove, chid, msg) + + if msg.sortkey_timestamp: + # Is this the last un-acked message we're waiting for? + last_unacked = sum( + len(sent) for sent in self.ps.updates_sent.itervalues() + ) == 1 + + if (msg.sortkey_timestamp == self.ps.current_timestamp or + last_unacked): + # If it's the last message in the batch, or last un-acked + # message + d = self.force_retry( + self.ps.message.update_last_message_read, + self.ps.uaid_obj, + self.ps.current_timestamp, + ) + d.addBoth(self._handle_webpush_update_remove, chid, msg) + else: + # It's timestamped, but not the last of this batch, + # so we just remove it from local tracking + self._handle_webpush_update_remove(None, chid, msg) + d = None + else: + # No sortkey_timestamp, so legacy/topic message, delete + d = self.force_retry(self.ps.message.delete_message, msg) + # We don't remove the update until we know the delete ran + # This is because we don't use range queries on dynamodb and + # we need to make sure this notification is deleted from the + # db before we query it again (to avoid dupes). + d.addBoth(self._handle_webpush_update_remove, chid, msg) return d def _handle_webpush_update_remove(self, result, chid, notif): @@ -1262,7 +1348,12 @@ def check_missed_notifications(self, results, resume=False): return # Should we check again? - if self.ps._check_notifications or self.ps._more_notifications: + if self.ps._more_notifications: + self.process_notifications() + elif self.ps._check_notifications: + # If we were told to check notifications, start over since we might + # have missed a topic message + self.ps.scan_timestamps = False self.process_notifications() def bad_message(self, typ, message=None, url=DEFAULT_WS_ERR):