diff --git a/autopush/db.py b/autopush/db.py index 60614e5e..e451e125 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -47,28 +47,40 @@ from boto.dynamodb2.layer1 import DynamoDBConnection from boto.dynamodb2.table import Table, Item from boto.dynamodb2.types import NUMBER -from typing import Iterable, List # flake8: noqa +from typing import ( # noqa + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + TypeVar, + Tuple, + Union, +) from autopush.exceptions import AutopushException +from autopush.metrics import IMetrics # noqa from autopush.utils import ( generate_hash, normalize_id, WebPushNotification, ) +# Typing +T = TypeVar('T') # noqa + key_hash = "" TRACK_DB_CALLS = False DB_CALLS = [] def get_month(delta=0): + # type: (int) -> datetime.date """Basic helper function to get a datetime.date object iterations months ahead/behind of now. - :type delta: int - - :rtype: datetime.datetime - """ new = last = datetime.date.today() # Move until we hit a new month, this avoids having to manually @@ -85,12 +97,15 @@ def get_month(delta=0): def hasher(uaid): + # type: (str) -> str + """Hashes a key using a key_hash if present""" if key_hash: return generate_hash(key_hash, uaid) return uaid def dump_uaid(uaid_data): + # type: (Union[Dict[str, Any], Item]) -> str """Return a dict for a uaid. 
This is utilized instead of repr since some db methods return a @@ -105,6 +120,7 @@ def dump_uaid(uaid_data): def make_rotating_tablename(prefix, delta=0, date=None): + # type: (str, int, Optional[datetime.date]) -> str """Creates a tablename for table rotation based on a prefix with a given month delta.""" if not date: @@ -114,6 +130,7 @@ def make_rotating_tablename(prefix, delta=0, date=None): def create_rotating_message_table(prefix="message", read_throughput=5, write_throughput=5, delta=0): + # type: (str, int, int, int) -> Table """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta) return Table.create(tablename, @@ -127,6 +144,7 @@ def create_rotating_message_table(prefix="message", read_throughput=5, def get_rotating_message_table(prefix="message", delta=0, date=None, message_read_throughput=5, message_write_throughput=5): + # type: (str, int, Optional[datetime.date], int, int) -> Table """Gets the message table for the current month.""" db = DynamoDBConnection() dblist = db.list_tables()["TableNames"] @@ -142,6 +160,7 @@ def get_rotating_message_table(prefix="message", delta=0, date=None, def create_router_table(tablename="router", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Create a new router table The last_connect index is a value used to determine the last month a user @@ -164,13 +183,14 @@ def create_router_table(tablename="router", read_throughput=5, 'AccessIndex', parts=[ HashKey('last_connect', - data_type=NUMBER)], + data_type=NUMBER)], throughput=dict(read=5, write=5))], ) def create_storage_table(tablename="storage", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Create a new storage table for simplepush style notification storage""" return Table.create(tablename, schema=[HashKey("uaid"), RangeKey("chid")], @@ -180,6 +200,7 @@ def create_storage_table(tablename="storage", read_throughput=5, def _make_table(table_func, 
tablename, read_throughput, write_throughput): + # type: (Callable[[str, int, int], Table], str, int, int) -> Table """Private common function to make a table with a table func""" db = DynamoDBConnection() dblist = db.list_tables()["TableNames"] @@ -191,6 +212,7 @@ def _make_table(table_func, tablename, read_throughput, write_throughput): def get_router_table(tablename="router", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Get the main router table object Creates the table if it doesn't already exist, otherwise returns the @@ -203,6 +225,7 @@ def get_router_table(tablename="router", read_throughput=5, def get_storage_table(tablename="storage", read_throughput=5, write_throughput=5): + # type: (str, int, int) -> Table """Get the main storage table object Creates the table if it doesn't already exist, otherwise returns the @@ -214,6 +237,7 @@ def get_storage_table(tablename="storage", read_throughput=5, def preflight_check(storage, router, uaid="deadbeef00000000deadbeef00000000"): + # type: (Storage, Router, str) -> None """Performs a pre-flight check of the storage/router/message to ensure appropriate permissions for operation. 
@@ -255,6 +279,7 @@ def preflight_check(storage, router, uaid="deadbeef00000000deadbeef00000000"): def track_provisioned(func): + # type: (Callable[..., T]) -> Callable[..., T] """Tracks provisioned exceptions and increments a metric for them named after the function decorated""" @wraps(func) @@ -276,13 +301,8 @@ def wrapper(self, *args, **kwargs): def has_connected_this_month(item): - """Whether or not a router item has connected this month - - :type item: dict - - :rtype: bool - - """ + # type: (Dict[str, Any]) -> bool + """Whether or not a router item has connected this month""" last_connect = item.get("last_connect") if not last_connect: return False @@ -293,16 +313,13 @@ def has_connected_this_month(item): def generate_last_connect(): + # type: () -> int """Generate a last_connect This intentionally generates a limited set of keys for each month in a known sequence. For each month, there's 24 hours * 10 random numbers for a total of 240 keys per month depending on when the user migrates forward. - :type date: datetime.datetime - - :rtype: int - """ today = datetime.datetime.today() val = "".join([ @@ -315,15 +332,12 @@ def generate_last_connect(): def generate_last_connect_values(date): + # type: (datetime.date) -> Iterable[int] """Generator of last_connect values for a given date Creates an iterator that yields all the valid values for ``last_connect`` for a given year/month. - :type date: datetime.datetime - - :rtype: Iterable[int] - """ year = str(date.year) month = str(date.month).zfill(2) @@ -337,6 +351,7 @@ def generate_last_connect_values(date): class Storage(object): """Create a Storage table abstraction on top of a DynamoDB Table object""" def __init__(self, table, metrics): + # type: (Table, IMetrics) -> None """Create a new Storage object :param table: :class:`Table` object. 
@@ -350,6 +365,7 @@ def __init__(self, table, metrics): @track_provisioned def fetch_notifications(self, uaid): + # type: (str) -> List[Dict[str, Any]] """Fetch all notifications for a UAID :raises: @@ -363,6 +379,7 @@ def fetch_notifications(self, uaid): @track_provisioned def save_notification(self, uaid, chid, version): + # type: (str, str, Optional[int]) -> bool """Save a notification for the UAID :raises: @@ -388,10 +405,10 @@ def save_notification(self, uaid, chid, version): return False def delete_notification(self, uaid, chid, version=None): + # type: (str, str, Optional[int]) -> bool """Delete a notification for a UAID :returns: Whether or not the notification was able to be deleted. - :rtype: bool """ try: @@ -411,6 +428,7 @@ def delete_notification(self, uaid, chid, version=None): class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" def __init__(self, table, metrics): + # type: (Table, IMetrics) -> None """Create a new Message object :param table: :class:`Table` object. @@ -424,6 +442,7 @@ def __init__(self, table, metrics): @track_provisioned def register_channel(self, uaid, channel_id): + # type: (str, str) -> bool """Register a channel for a given uaid""" conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) @@ -441,6 +460,7 @@ def register_channel(self, uaid, channel_id): @track_provisioned def unregister_channel(self, uaid, channel_id, **kwargs): + # type: (str, str, **str) -> bool """Remove a channel registration for a given uaid""" conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) @@ -466,6 +486,7 @@ def unregister_channel(self, uaid, channel_id, **kwargs): @track_provisioned def all_channels(self, uaid): + # type: (str) -> Tuple[bool, Set[str]] """Retrieve a list of all channels for a given uaid""" # Note: This only returns the chids associated with the UAID. 
@@ -480,6 +501,7 @@ def all_channels(self, uaid): @track_provisioned def save_channels(self, uaid, channels): + # type: (str, Set[str]) -> None """Save out a set of channels""" self.table.put_item(data=dict( uaid=hasher(uaid), @@ -489,12 +511,8 @@ def save_channels(self, uaid, channels): @track_provisioned def store_message(self, notification): - """Stores a WebPushNotification in the message table - - :type notification: WebPushNotification - :type timestamp: int - - """ + # type: (WebPushNotification) -> bool + """Stores a WebPushNotification in the message table""" item = dict( uaid=hasher(notification.uaid.hex), chidmessageid=notification.sort_key, @@ -509,11 +527,8 @@ def store_message(self, notification): @track_provisioned def delete_message(self, notification): - """Deletes a specific message - - :type notification: WebPushNotification - - """ + # type: (WebPushNotification) -> bool + """Deletes a specific message""" if notification.update_id: try: self.table.delete_item( @@ -530,25 +545,94 @@ def delete_message(self, notification): return True @track_provisioned - def fetch_messages(self, uaid, limit=10): + def fetch_messages( + self, + uaid, # type: uuid.UUID + limit=10, # type: int + ): + # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] """Fetches messages for a uaid - :type uaid: uuid.UUID - :type limit: int + :returns: A tuple of the last timestamp to read for timestamped + messages and the list of non-timestamped messages. """ # Eagerly fetches all results in the result set. 
- results = self.table.query_2(uaid__eq=hasher(uaid.hex), - chidmessageid__gt=" ", - consistent=True, limit=limit) - return [ + results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), + chidmessageid__lt="02", + consistent=True, limit=limit)) + + # First extract the position if applicable, slightly higher than 01: + # to ensure we don't load any 01 remainders that didn't get deleted + # yet + last_position = None + if results: + # Ensure we return an int, as boto2 can return Decimals + if results[0].get("current_timestamp"): + last_position = int(results[0]["current_timestamp"]) + + return last_position, [ + WebPushNotification.from_message_table(uaid, x) + for x in results[1:] + ] + + @track_provisioned + def fetch_timestamp_messages( + self, + uaid, # type: uuid.UUID + timestamp=None, # type: Optional[int] + limit=10, # type: int + ): + # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] + """Fetches timestamped messages for a uaid + + Note that legacy messages start with a hex UUID, so they may be mixed + in with timestamp messages beginning with 02. As such we only move our + last_position forward to the last timestamped message. + + :returns: A tuple of the last timestamp to read and the list of + timestamped messages. 
+ + """ + # Turn the timestamp into a proper sort key + if timestamp: + sortkey = "02:{timestamp}:z".format(timestamp=timestamp) + else: + sortkey = "01;" + + results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), + chidmessageid__gt=sortkey, + consistent=True, limit=limit)) + notifs = [ WebPushNotification.from_message_table(uaid, x) for x in results ] + ts_notifs = [x for x in notifs if x.sortkey_timestamp] + last_position = None + if ts_notifs: + last_position = ts_notifs[-1].sortkey_timestamp + return last_position, notifs + + @track_provisioned + def update_last_message_read(self, uaid, timestamp): + # type: (uuid.UUID, int) -> bool + """Update the last read timestamp for a user""" + conn = self.table.connection + db_key = self.encode({"uaid": hasher(uaid.hex), "chidmessageid": " "}) + expr = "SET current_timestamp=:timestamp" + expr_values = self.encode({":timestamp": timestamp}) + conn.update_item( + self.table.table_name, + db_key, + update_expression=expr, + expression_attribute_values=expr_values, + ) + return True class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" def __init__(self, table, metrics): + # type: (Table, IMetrics) -> None """Create a new Router object :param table: :class:`Table` object. @@ -561,10 +645,9 @@ def __init__(self, table, metrics): self.encode = table._encode_keys def get_uaid(self, uaid): + # type: (str) -> Item """Get the database record for the UAID - :returns: User item - :rtype: :class:`~boto.dynamodb2.items.Item` :raises: :exc:`ItemNotFound` if there is no record for this UAID. :exc:`ProvisionedThroughputExceededException` if dynamodb table @@ -592,13 +675,13 @@ def get_uaid(self, uaid): @track_provisioned def register_user(self, data): + # type: (Dict[str, Any]) -> Tuple[bool, Dict[str, Any], Dict[str, Any]] """Register this user If a record exists with a newer ``connected_at``, then the user will not be registered. :returns: Whether the user was registered or not. 
- :rtype: tuple :raises: :exc:`ProvisionedThroughputExceededException` if dynamodb table exceeds throughput. @@ -647,6 +730,8 @@ def register_user(self, data): @track_provisioned def drop_user(self, uaid): + # type: (str) -> bool + """Drops a user record""" # The following hack ensures that only uaids that exist and are # deleted return true. huaid = hasher(uaid) @@ -654,16 +739,14 @@ def drop_user(self, uaid): expected={"uaid__eq": huaid}) def delete_uaids(self, uaids): - """Issue a batch delete call for the given uaids - - :type uaids: List[str] - - """ + # type: (List[str]) -> None + """Issue a batch delete call for the given uaids""" with self.table.batch_write() as batch: for uaid in uaids: batch.delete_item(uaid=uaid) def drop_old_users(self, months_ago=2): + # type: (int) -> Iterable[int] """Drops user records that have no recent connection Utilizes the last_connect index to locate users that haven't @@ -684,10 +767,8 @@ def drop_old_users(self, months_ago=2): quickly as possible. :param months_ago: how many months ago since the last connect - :type months_ago: int :returns: Iterable of how many deletes were run - :rtype: Iterable[int] """ prior_date = get_month(-months_ago) @@ -713,17 +794,20 @@ def drop_old_users(self, months_ago=2): @track_provisioned def update_message_month(self, uaid, month): + # type: (str, str) -> bool """Update the route tables current_message_month Note that we also update the last_connect at this point since webpush - users when connecting will always call this once that month. + users when connecting will always call this once that month. The + current_timestamp is not touched here; it is stored in the monthly + message table, so a new month has no last read timestamp. 
""" conn = self.table.connection db_key = self.encode({"uaid": hasher(uaid)}) - expr = "SET current_month=:curmonth, last_connect=:last_connect" + expr = ("SET current_month=:curmonth, last_connect=:last_connect") expr_values = self.encode({":curmonth": month, - ":last_connect": generate_last_connect() + ":last_connect": generate_last_connect(), }) conn.update_item( self.table.table_name, @@ -735,13 +819,13 @@ def update_message_month(self, uaid, month): @track_provisioned def clear_node(self, item): + # type: (Item) -> bool """Given a router item and remove the node_id The node_id will only be cleared if the ``connected_at`` matches up with the item's ``connected_at``. :returns: Whether the node was cleared or not. - :rtype: bool :raises: :exc:`ProvisionedThroughputExceededException` if dynamodb table exceeds throughput. diff --git a/autopush/settings.py b/autopush/settings.py index 7c5126c9..2936ed3a 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -211,6 +211,10 @@ def __init__(self, self.ami_id = ami_id + # Generate messages per legacy rules, only used for testing to + # generate legacy data. 
+ self._notification_legacy = False + @property def message(self): """Property that access the current message table""" diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index f27ca00f..e930be53 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -315,7 +315,8 @@ def test_message_storage(self): message.store_message(make_webpush_notification(self.uaid, chid)) message.store_message(make_webpush_notification(self.uaid, chid)) - all_messages = list(message.fetch_messages(uuid.UUID(self.uaid))) + _, all_messages = message.fetch_timestamp_messages( + uuid.UUID(self.uaid), " ") eq_(len(all_messages), 3) def test_message_storage_overwrite(self): diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index d1a7d751..cd2d0fa0 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -139,6 +139,17 @@ def handle_finish(result): self.message.delete(self._make_req('ignored')) return self.finish_deferred + def test_delete_invalid_timestamp_token(self): + tok = ":".join(["02", str(dummy_chid)]) + self.fernet_mock.decrypt.return_value = tok + + def handle_finish(result): + self.status_mock.assert_called_with(400, reason=None) + self.finish_deferred.addCallback(handle_finish) + + self.message.delete(self._make_req('ignored')) + return self.finish_deferred + def test_delete_success(self): tok = ":".join(["m", dummy_uaid.hex, str(dummy_chid)]) self.fernet_mock.decrypt.return_value = tok diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index 61940d12..5d013458 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -8,6 +8,7 @@ import time import urlparse import uuid +from contextlib import contextmanager from StringIO import StringIO from unittest.case import SkipTest @@ -448,6 +449,12 @@ def shut_down(self, client=None): if client: yield client.disconnect() + @contextmanager + def legacy_endpoint(self): + 
self._settings._notification_legacy = True + yield + self._settings._notification_legacy = False + class TestSimple(IntegrationBase): @inlineCallbacks @@ -703,15 +710,16 @@ def test_uaid_resumption_on_reconnect(self): class TestWebPush(IntegrationBase): @inlineCallbacks - def test_hello_only_has_two_calls(self): + def test_hello_only_has_three_calls(self): db.TRACK_DB_CALLS = True client = Client(self._ws_url, use_webpush=True) yield client.connect() result = yield client.hello() ok_(result != {}) eq_(result["use_webpush"], True) - yield client.wait_for(lambda: len(db.DB_CALLS) == 2) - eq_(db.DB_CALLS, ['register_user', 'fetch_messages']) + yield client.wait_for(lambda: len(db.DB_CALLS) == 3) + eq_(db.DB_CALLS, ['register_user', 'fetch_messages', + 'fetch_timestamp_messages']) db.DB_CALLS = [] db.TRACK_DB_CALLS = False @@ -872,17 +880,18 @@ def test_multiple_delivery_repeat_without_ack(self): yield self.shut_down(client) @inlineCallbacks - def test_multiple_delivery_with_single_ack(self): + def test_multiple_legacy_delivery_with_single_ack(self): data = str(uuid.uuid4()) data2 = str(uuid.uuid4()) client = yield self.quick_register(use_webpush=True) yield client.disconnect() ok_(client.channels) - yield client.send_notification(data=data) - yield client.send_notification(data=data2) + with self.legacy_endpoint(): + yield client.send_notification(data=data) + yield client.send_notification(data=data2) yield client.connect() yield client.hello() - result = yield client.get_notification() + result = yield client.get_notification(timeout=5) ok_(result != {}) ok_(result["data"] in map(base64url_encode, [data, data2])) result = yield client.get_notification() @@ -901,6 +910,45 @@ def test_multiple_delivery_with_single_ack(self): eq_(result, None) yield self.shut_down(client) + @inlineCallbacks + def test_multiple_delivery_with_single_ack(self): + data = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield 
client.disconnect() + ok_(client.channels) + yield client.send_notification(data=data) + yield client.send_notification(data=data2) + yield client.connect() + yield client.hello() + result = yield client.get_notification() + ok_(result != {}) + eq_(result["data"], base64url_encode(data)) + result2 = yield client.get_notification() + ok_(result2 != {}) + eq_(result2["data"], base64url_encode(data2)) + yield client.ack(result["channelID"], result["version"]) + + yield client.disconnect() + yield client.connect() + yield client.hello() + result = yield client.get_notification() + ok_(result != {}) + eq_(result["data"], base64url_encode(data)) + ok_(result["messageType"], "notification") + result2 = yield client.get_notification() + ok_(result2 != {}) + eq_(result2["data"], base64url_encode(data2)) + yield client.ack(result2["channelID"], result2["version"]) + + # Verify no messages are delivered + yield client.disconnect() + yield client.connect() + yield client.hello() + result = yield client.get_notification() + ok_(result is None) + yield self.shut_down(client) + @inlineCallbacks def test_multiple_delivery_with_multiple_ack(self): data = str(uuid.uuid4()) @@ -1021,6 +1069,64 @@ def test_ttl_expired(self): eq_(result, None) yield self.shut_down(client) + @inlineCallbacks + def test_ttl_batch_expired_and_good_one(self): + data = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + for x in range(0, 12): + yield client.send_notification(data=data, ttl=1) + + yield client.send_notification(data=data2) + time.sleep(1.5) + yield client.connect() + yield client.hello() + result = yield client.get_notification(timeout=4) + ok_(result is not None) + eq_(result["headers"]["encryption"], client._crypto_key) + eq_(result["data"], base64url_encode(data2)) + eq_(result["messageType"], "notification") + result = yield client.get_notification() + eq_(result, None) + yield self.shut_down(client) + + 
@inlineCallbacks + def test_ttl_batch_partly_expired_and_good_one(self): + data = str(uuid.uuid4()) + data1 = str(uuid.uuid4()) + data2 = str(uuid.uuid4()) + client = yield self.quick_register(use_webpush=True) + yield client.disconnect() + for x in range(0, 6): + yield client.send_notification(data=data) + + for x in range(0, 6): + yield client.send_notification(data=data1, ttl=1) + + yield client.send_notification(data=data2) + time.sleep(1.5) + yield client.connect() + yield client.hello() + + # Pull out and ack the first + for x in range(0, 6): + result = yield client.get_notification(timeout=4) + ok_(result is not None) + eq_(result["data"], base64url_encode(data)) + yield client.ack(result["channelID"], result["version"]) + + # Should have one more that is data2, this will only arrive if the + # other six were acked as that hits the batch size + result = yield client.get_notification(timeout=4) + ok_(result is not None) + eq_(result["data"], base64url_encode(data2)) + + # No more + result = yield client.get_notification() + eq_(result, None) + yield self.shut_down(client) + @inlineCallbacks def test_message_without_crypto_headers(self): data = str(uuid.uuid4()) @@ -1154,9 +1260,11 @@ def test_webpush_monthly_rotation(self): # Send in a notification, verify it landed in last months notification # table data = uuid.uuid4().hex - yield client.send_notification(data=data) - notifs = yield deferToThread(lm_message.fetch_messages, - uuid.UUID(client.uaid)) + with self.legacy_endpoint(): + yield client.send_notification(data=data) + ts, notifs = yield deferToThread(lm_message.fetch_timestamp_messages, + uuid.UUID(client.uaid), + " ") eq_(len(notifs), 1) # Connect the client, verify the migration @@ -1247,9 +1355,11 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Send in a notification, verify it landed in last months notification # table data = uuid.uuid4().hex - yield client.send_notification(data=data) - notifs = yield 
deferToThread(lm_message.fetch_messages, - uuid.UUID(client.uaid)) + with self.legacy_endpoint(): + yield client.send_notification(data=data) + _, notifs = yield deferToThread(lm_message.fetch_timestamp_messages, + uuid.UUID(client.uaid), + " ") eq_(len(notifs), 1) # Connect the client, verify the migration diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index dc8818b5..21a89c9c 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -583,7 +583,8 @@ def fake_msg(data): return (True, msg_data, data) mock_msg = Mock(wraps=db.Message) - mock_msg.fetch_messages.return_value = [] + mock_msg.fetch_messages.return_value = "01;", [] + mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) self.proto.ap_settings.router.register_user = fake_msg # massage message_tables to include our fake range @@ -656,7 +657,8 @@ def fake_msg(data): return (True, msg_data, data) mock_msg = Mock(wraps=db.Message) - mock_msg.fetch_messages.return_value = [] + mock_msg.fetch_messages.return_value = "01;", [] + mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) self.proto.ap_settings.router.register_user = fake_msg # massage message_tables to include our fake range @@ -730,7 +732,8 @@ def test_hello_webpush_uses_one_db_call(self): channelIDs=[])) def check_result(msg): - eq_(db.DB_CALLS, ['register_user', 'fetch_messages']) + eq_(db.DB_CALLS, ['register_user', 'fetch_messages', + 'fetch_timestamp_messages']) eq_(msg["status"], 200) db.DB_CALLS = [] db.TRACK_DB_CALLS = False @@ -1882,6 +1885,13 @@ def test_process_notif_doesnt_run_after_stop(self): self.proto.process_notifications() eq_(self.proto.ps._notification_fetch, None) + def test_check_notif_doesnt_run_after_stop(self): + self._connect() + self.proto.ps.uaid = uuid.uuid4().hex + self.proto.ps._should_stop = True + self.proto.check_missed_notifications(None) + 
eq_(self.proto.ps._notification_fetch, None) + def test_process_notif_paused_on_finish(self): self._connect() self.proto.ps.uaid = uuid.uuid4().hex @@ -1896,7 +1906,8 @@ def test_notif_finished_with_webpush(self): self.proto.ps.use_webpush = True self.proto.deferToLater = Mock() self.proto.ps._check_notifications = True - self.proto.finish_notifications(None) + self.proto.ps.scan_timestamps = True + self.proto.finish_notifications((None, [])) ok_(self.proto.deferToLater.called) def test_notif_finished_with_webpush_with_notifications(self): @@ -1912,7 +1923,7 @@ def test_notif_finished_with_webpush_with_notifications(self): ) self.proto.ps.updates_sent[str(notif.channel_id)] = [] - self.proto.finish_webpush_notifications([notif]) + self.proto.finish_webpush_notifications((None, [notif])) ok_(self.send_mock.called) def test_notif_finished_with_webpush_with_old_notifications(self): @@ -1930,7 +1941,7 @@ def test_notif_finished_with_webpush_with_old_notifications(self): self.proto.ps.updates_sent[str(notif.channel_id)] = [] self.proto.force_retry = Mock() - self.proto.finish_webpush_notifications([notif]) + self.proto.finish_webpush_notifications((None, [notif])) ok_(self.proto.force_retry.called) ok_(not self.send_mock.called) diff --git a/autopush/utils.py b/autopush/utils.py index 4576120a..c808b6ca 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -14,16 +14,16 @@ attrs, attrib ) -from boto.dynamodb2.items import Item # flake8: noqa -from cryptography.fernet import Fernet # flake8: noqa +from boto.dynamodb2.items import Item # noqa +from cryptography.fernet import Fernet # noqa from jose import jwt -from typing import ( +from typing import ( # noqa Any, Dict, Optional, Union, Tuple, -) # flake8: noqa +) from ua_parser import user_agent_parser from autopush.exceptions import InvalidTokenException @@ -57,6 +57,10 @@ $ """, re.VERBOSE) +# Multipliers from seconds: MS_MULT -> ms; NS_MULT is 10**6 (us, despite name) +MS_MULT = pow(10, 3) +NS_MULT = pow(10, 6) + def 
normalize_id(ident): # type: (Union[uuid.UUID, str]) -> str @@ -269,6 +273,7 @@ class WebPushNotification(object): data = attrib(default=None) # type: Optional[str] headers = attrib(default=None) # type: Optional[Dict[str, str]] timestamp = attrib(default=Factory(lambda: int(time.time()))) # type: int + sortkey_timestamp = attrib(default=None) # type: Optional[int] topic = attrib(default=None) # type: Optional[str] message_id = attrib(default=None) # type: str @@ -277,19 +282,26 @@ class WebPushNotification(object): # message with any update_id should be removed. update_id = attrib(default=None) # type: str + # Whether this notification should follow legacy non-topic rules + legacy = attrib(default=False) # type: bool + def generate_message_id(self, fernet): # type: (Fernet) -> str """Generate a message-id suitable for accessing the message - For non-topic messages, no sort_key version is currently used and the - message-id is: - - Encrypted(m : uaid.hex : channel_id.hex) - For topic messages, a sort_key version of 01 is used, and the topic is included for reference: - Encrypted(01 : uaid.hex : channel_id.hex : topic) + Encrypted('01' : uaid.hex : channel_id.hex : topic) + + For non-topic messages, a sort_key version of 02 is used: + + Encrypted('02' : uaid.hex : channel_id.hex : timestamp) + + For legacy non-topic messages, no sort_key version was used and the + message-id was: + + Encrypted('m' : uaid.hex : channel_id.hex) This is a blocking call. 
@@ -297,22 +309,33 @@ def generate_message_id(self, fernet): if self.topic: msg_key = ":".join(["01", self.uaid.hex, self.channel_id.hex, self.topic]) - else: + elif self.legacy: msg_key = ":".join(["m", self.uaid.hex, self.channel_id.hex]) + else: + self.sortkey_timestamp = self.sortkey_timestamp or ns_time() + msg_key = ":".join(["02", self.uaid.hex, self.channel_id.hex, + str(self.sortkey_timestamp)]) self.message_id = fernet.encrypt(msg_key.encode('utf8')) self.update_id = self.message_id return self.message_id @staticmethod def parse_decrypted_message_id(decrypted_token): - # type: (str) -> Dict[str, str] + # type: (str) -> Dict[str, Any] """Parses a decrypted message-id into component parts""" topic = None + sortkey_timestamp = None if decrypted_token.startswith("01:"): info = decrypted_token.split(":") if len(info) != 4: raise InvalidTokenException("Incorrect number of token parts.") api_ver, uaid, chid, topic = info + elif decrypted_token.startswith("02:"): + info = decrypted_token.split(":") + if len(info) != 4: + raise InvalidTokenException("Incorrect number of token parts.") + api_ver, uaid, chid, raw_sortkey = info + sortkey_timestamp = int(raw_sortkey) else: info = decrypted_token.split(":") if len(info) != 3: @@ -324,6 +347,7 @@ def parse_decrypted_message_id(decrypted_token): uaid=uaid, chid=chid, topic=topic, + sortkey_timestamp=sortkey_timestamp, ) def cleanup_headers(self): @@ -358,27 +382,54 @@ def cleanup_headers(self): @property def sort_key(self): # type: () -> str - """Return an appropriate sort_key for this notification""" + """Return an appropriate sort_key for this notification + + For new messages: + + 02:{sortkey_timestamp}:{chid} + + For topic messages: + + 01:{chid}:{topic} + + Old format for non-topic messages that is no longer returned: + + {chid}:{message_id} + + """ chid = normalize_id(self.channel_id) if self.topic: return "01:{chid}:{topic}".format(chid=chid, topic=self.topic) + elif self.legacy: + return 
"{chid}:{message_id}".format( + chid=chid, message_id=self.message_id + ) else: - return "{chid}:{message_id}".format(chid=chid, - message_id=self.message_id) + # Created as late as possible when storing a message + self.sortkey_timestamp = self.sortkey_timestamp or ns_time() + return "02:{sortkey_timestamp}:{chid}".format( + sortkey_timestamp=self.sortkey_timestamp, + chid=chid, + ) @staticmethod def parse_sort_key(sort_key): - # type: (str) -> Dict[str, str] + # type: (str) -> Dict[str, Any] """Parse the sort key from the database""" topic = None + sortkey_timestamp = None message_id = None - if re.match(r'^\d\d:', sort_key): + if sort_key.startswith("01:"): api_ver, channel_id, topic = sort_key.split(":") + elif sort_key.startswith("02:"): + api_ver, raw_sortkey, channel_id = sort_key.split(":") + sortkey_timestamp = int(raw_sortkey) else: channel_id, message_id = sort_key.split(":") api_ver = "00" return dict(api_ver=api_ver, channel_id=channel_id, - topic=topic, message_id=message_id) + topic=topic, message_id=message_id, + sortkey_timestamp=sortkey_timestamp) @property def location(self): @@ -401,32 +452,41 @@ def from_message_table(cls, uaid, item): # type: (uuid.UUID, Union[Dict[str, Any], Item]) -> WebPushNotification """Create a WebPushNotification from a message table item""" key_info = cls.parse_sort_key(item["chidmessageid"]) - if key_info.get("topic"): + if key_info["api_ver"] in ["01", "02"]: key_info["message_id"] = item["updateid"] - return cls(uaid=uaid, - channel_id=uuid.UUID(key_info["channel_id"]), - data=item.get("data"), - headers=item.get("headers"), - ttl=item["ttl"], - topic=key_info.get("topic"), - message_id=key_info["message_id"], - update_id=item.get("updateid"), - timestamp=item.get("timestamp"), - ) + notif = cls( + uaid=uaid, + channel_id=uuid.UUID(key_info["channel_id"]), + data=item.get("data"), + headers=item.get("headers"), + ttl=item["ttl"], + topic=key_info.get("topic"), + message_id=key_info["message_id"], + 
update_id=item.get("updateid"), + timestamp=item.get("timestamp"), + sortkey_timestamp=key_info.get("sortkey_timestamp") + ) + + # Ensure we generate the sort-key properly for legacy messages + if key_info["api_ver"] == "00": + notif.legacy = True + + return notif @classmethod - def from_webpush_request_schema(cls, data, fernet): - # type: (Dict[str, Any], Fernet) -> WebPushNotification + def from_webpush_request_schema(cls, data, fernet, legacy=False): + # type: (Dict[str, Any], Fernet, bool) -> WebPushNotification """Create a WebPushNotification from a validated WebPushRequestSchema This is a blocking call. """ sub = data["subscription"] - notif = cls(uaid=sub["uaid"], channel_id=sub["chid"], - data=data["body"], headers=data["headers"], - ttl=data["headers"]["ttl"], - topic=data["headers"]["topic"]) + notif = cls( + uaid=sub["uaid"], channel_id=sub["chid"], data=data["body"], + headers=data["headers"], ttl=data["headers"]["ttl"], + topic=data["headers"]["topic"], legacy=legacy, + ) if notif.data: notif.cleanup_headers() @@ -458,6 +518,7 @@ def from_message_id(cls, message_id, fernet): ttl=None, topic=key_info["topic"], message_id=message_id, + sortkey_timestamp=key_info.get("sortkey_timestamp"), ) if key_info["topic"]: notif.update_id = message_id @@ -512,7 +573,7 @@ def websocket_format(self): messageType="notification", channelID=normalize_id(self.channel_id), version=self.version, - ) + ) # type: Dict[str, Any] if self.data: payload["data"] = self.data payload["headers"] = { @@ -524,4 +585,10 @@ def websocket_format(self): def ms_time(): # type: () -> int """Return current time.time call as ms and a Python int""" - return int(time.time() * 1000) + return int(time.time() * MS_MULT) + + +def ns_time(): + # type: () -> int + """Return current time.time call as a ns int""" + return int(time.time() * NS_MULT) diff --git a/autopush/web/push_validation.py b/autopush/web/push_validation.py index 7a672b5e..01a2f8fc 100644 --- a/autopush/web/push_validation.py +++
b/autopush/web/push_validation.py @@ -289,6 +289,7 @@ def fixup_output(self, d): # Set the notification based on the validated request schema data d["notification"] = WebPushNotification.from_webpush_request_schema( - data=d, fernet=self.context["settings"].fernet + data=d, fernet=self.context["settings"].fernet, + legacy=self.context["settings"]._notification_legacy, ) return d diff --git a/autopush/websocket.py b/autopush/websocket.py index 68c95718..5bd58d85 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -32,8 +32,8 @@ import json import time import uuid -from collections import defaultdict, namedtuple -from functools import wraps +from collections import defaultdict +from functools import partial, wraps from random import randrange from autobahn.twisted.websocket import WebSocketServerProtocol @@ -59,7 +59,12 @@ from twisted.python import failure from twisted.web._newclient import ResponseFailed from twisted.web.resource import Resource -from typing import List # flake8: noqa +from typing import ( # noqa + Dict, + List, + Optional, + Tuple, +) from zope.interface import implements from autopush import __version__ @@ -68,8 +73,9 @@ has_connected_this_month, hasher, generate_last_connect, - dump_uaid + dump_uaid, ) +from autopush.db import Message # noqa from autopush.noseplugin import track_object from autopush.protocol import IgnoreBody from autopush.utils import ( @@ -149,6 +155,10 @@ class PushState(object): 'message', 'rotate_message_table', + # Timestamped message handling + 'scan_timestamps', + 'current_timestamp', + 'ping_time_out', '_check_notifications', '_more_notifications', @@ -188,8 +198,8 @@ def __init__(self, settings, request): self.metrics = settings.metrics self.metrics.increment("client.socket.connect", tags=self._base_tags or None) - self.uaid = None - self.uaid_obj = None + self.uaid = None # Optional[str] + self.uaid_obj = None # Optional[uuid.UUID] self.uaid_hash = "" self.last_ping = 0 self.check_storage = False @@ 
-206,6 +216,10 @@ def __init__(self, settings, request): self._check_notifications = False self._more_notifications = False + # Timestamp message defaults + self.scan_timestamps = False + self.current_timestamp = None + # Hanger for common actions we defer self._notification_fetch = None self._register = None @@ -218,6 +232,7 @@ def __init__(self, settings, request): @property def message(self): + # type: () -> Message """Property to access the currently used message table""" return self.settings.message_tables[self.message_month] @@ -322,7 +337,8 @@ def log_failure(self, failure, **kwargs): if isinstance(exc, JSONResponseError): self.log.info("JSONResponseError: {exc}", exc=exc, **kwargs) else: - self.log.failure(format="Unexpected error", failure=failure, **kwargs) + self.log.failure(format="Unexpected error", failure=failure, + **kwargs) @property def paused(self): @@ -848,8 +864,7 @@ def process_notifications(self): self.ps._more_notifications = True if self.ps.use_webpush: - d = self.deferToThread(self.ps.message.fetch_messages, - self.ps.uaid_obj) + d = self.deferToThread(self.webpush_fetch()) else: d = self.deferToThread( self.ap_settings.storage.fetch_notifications, self.ps.uaid) @@ -859,6 +874,14 @@ def process_notifications(self): d.addErrback(self.error_notifications) self.ps._notification_fetch = d + def webpush_fetch(self): + """Helper to return an appropriate function to fetch messages""" + if self.ps.scan_timestamps: + return partial(self.ps.message.fetch_timestamp_messages, + self.ps.uaid_obj, self.ps.current_timestamp) + else: + return partial(self.ps.message.fetch_messages, self.ps.uaid_obj) + def error_notifications(self, fail): """errBack for notification check failing""" # If we error'd out on this important check, we drop the connection @@ -912,16 +935,29 @@ def finish_notifications(self, notifs): d = self.deferToLater(1, self.process_notifications) d.addErrback(self.trap_cancel) - def finish_webpush_notifications(self, notifs): - """webpush 
notification processor + def finish_webpush_notifications(self, result): + # type: (Tuple[str, List[WebPushNotification]]) -> None + """WebPush notification processor""" + timestamp, notifs = result - :type notifs: List[autopush.utils.WebPushNotification] + # If there's a timestamp, update our current one to it + if timestamp: + self.ps.current_timestamp = timestamp - """ if not notifs: - # No more notifications, we can stop. + # No more notifications, check timestamped? + if not self.ps.scan_timestamps: + # Scan for timestamped then + self.ps.scan_timestamps = True + d = self.deferToLater(0, self.process_notifications) + d.addErrback(self.trap_cancel) + return + + # No more notifications, and we've scanned timestamped. self.ps._more_notifications = False + self.ps.scan_timestamps = False if self.ps._check_notifications: + # Told to check again, start over self.ps._check_notifications = False d = self.deferToLater(1, self.process_notifications) d.addErrback(self.trap_cancel) @@ -935,16 +971,38 @@ def finish_webpush_notifications(self, notifs): # Send out all the notifications now = int(time.time()) + messages_sent = False for notif in notifs: # If the TTL is too old, don't deliver and fire a delete off if notif.expired(at_time=now): - self.force_retry(self.ps.message.delete_message, notif) - continue + if not notif.sortkey_timestamp: + # Delete non-timestamped messages + self.force_retry(self.ps.message.delete_message, notif) + + # nocover here as coverage gets confused on the line below + # for unknown reasons + continue # pragma: nocover self.ps.updates_sent[str(notif.channel_id)].append(notif) msg = notif.websocket_format() + messages_sent = True self.sendJSON(msg) + # Did we send any messages? 
+ if messages_sent: + return + + # No messages sent, update the record if needed + if self.ps.current_timestamp: + self.force_retry( + self.ps.message.update_last_message_read, + self.ps.uaid_obj, + self.ps.current_timestamp + ) + + # Schedule a new process check + self.check_missed_notifications(None) + def _rotate_message_table(self): """Function to fire off a message table copy of channels + update the router current_month entry""" @@ -1152,7 +1210,9 @@ def _handle_webpush_ack(self, chid, version, code): def ver_filter(notif): return notif.version == version - found = filter(ver_filter, self.ps.direct_updates[chid]) + found = filter( + ver_filter, self.ps.direct_updates[chid] + ) # type: List[WebPushNotification] if found: msg = found[0] size = len(msg.data) if msg.data else 0 @@ -1164,7 +1224,9 @@ def ver_filter(notif): self.ps.direct_updates[chid].remove(msg) return - found = filter(ver_filter, self.ps.updates_sent[chid]) + found = filter( + ver_filter, self.ps.updates_sent[chid] + ) # type: List[WebPushNotification] if found: msg = found[0] size = len(msg.data) if msg.data else 0 @@ -1173,12 +1235,36 @@ def ver_filter(notif): message_size=size, uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, code=code, **self.ps.raw_agent) - d = self.force_retry(self.ps.message.delete_message, msg) - # We don't remove the update until we know the delete ran - # This is because we don't use range queries on dynamodb and we - # need to make sure this notification is deleted from the db before - # we query it again (to avoid dupes). - d.addBoth(self._handle_webpush_update_remove, chid, msg) + + if msg.sortkey_timestamp: + # Is this the last un-acked message we're waiting for? 
+ last_unacked = sum( + len(self.ps.updates_sent[x]) for x in self.ps.updates_sent + ) == 1 + + if (msg.sortkey_timestamp == self.ps.current_timestamp or + last_unacked): + # If it's the last message in the batch, or last un-acked + # message + d = self.force_retry( + self.ps.message.update_last_message_read, + self.ps.uaid_obj, + self.ps.current_timestamp, + ) + d.addBoth(self._handle_webpush_update_remove, chid, msg) + else: + # It's timestamped, but not the last of this batch, + # so we just remove it from local tracking + self._handle_webpush_update_remove(None, chid, msg) + d = None + else: + # No sortkey_timestamp, so legacy/topic message, delete + d = self.force_retry(self.ps.message.delete_message, msg) + # We don't remove the update until we know the delete ran + # This is because we don't use range queries on dynamodb and + # we need to make sure this notification is deleted from the + # db before we query it again (to avoid dupes). + d.addBoth(self._handle_webpush_update_remove, chid, msg) return d def _handle_webpush_update_remove(self, result, chid, notif): @@ -1262,7 +1348,12 @@ def check_missed_notifications(self, results, resume=False): return # Should we check again? - if self.ps._check_notifications or self.ps._more_notifications: + if self.ps._more_notifications: + self.process_notifications() + elif self.ps._check_notifications: + # If we were told to check notifications, start over since we might + # have missed a topic message + self.ps.scan_timestamps = False self.process_notifications() def bad_message(self, typ, message=None, url=DEFAULT_WS_ERR):