From 72094691241c694ef1742bad28af077c1a290f9f Mon Sep 17 00:00:00 2001 From: jr conlin Date: Thu, 9 Nov 2017 11:09:36 -0800 Subject: [PATCH] feat: enforce AWS expiry on data lookups Closes #1051 --- autopush/db.py | 15 ++++++++-- autopush/tests/test_db.py | 45 ++++++++++++++++++++++++++++++ autopush/tests/test_endpoint.py | 7 ----- autopush/tests/test_integration.py | 2 +- 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/autopush/db.py b/autopush/db.py index bae7eaea..987a0ee1 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -523,6 +523,7 @@ def all_channels(self, uaid): # Note: This only returns the chids associated with the UAID. # Functions that call store_message() would be required to # update that list as well using register_channel() + # TODO: once expiry is properly integrated, use FilterExpression result = self.table.get_item( Key={ 'uaid': hasher(uaid), @@ -534,6 +535,9 @@ def all_channels(self, uaid): return False, set([]) if 'Item' not in result: return False, set([]) + now = int(time.time()) + if result['Item'].get('expiry', now) < now: + return False, set([]) return True, result['Item'].get("chids", set([])) @track_provisioned @@ -609,6 +613,7 @@ def fetch_messages( """ # Eagerly fetches all results in the result set. + # TODO: once expiry is properly integrated, use FilterExpression response = self.table.query( KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex)) & Key('chidmessageid').lt('02')), @@ -654,15 +659,17 @@ def fetch_timestamp_messages( else: sortkey = "01;" + # TODO: once expiry is properly integrated, use FilterExpression response = self.table.query( KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex)) & Key('chidmessageid').gt(sortkey)), ConsistentRead=True, Limit=limit ) + now = int(time.time()) notifs = [ WebPushNotification.from_message_table(uaid, x) for x in - response.get("Items") + response.get("Items") if x.get('expiry', now) >= now ] ts_notifs = [x for x in notifs if x.sortkey_timestamp] last_position = None @@ -717,6 +724,7 @@ def get_uaid(self, uaid): """ try: + # TODO: once expiry is properly integrated, use FilterExpression item = self.table.get_item( Key={ 'uaid': hasher(uaid) @@ -729,6 +737,9 @@ def get_uaid(self, uaid): item = item.get('Item') if item is None: raise ItemNotFound("uaid not found") + now = int(time.time()) + if item.get('expiry', now) < now: + raise ItemNotFound("uaid not found") if item.keys() == ['uaid']: # Incomplete record, drop it. self.drop_user(uaid) @@ -813,7 +824,7 @@ def drop_user(self, uaid): ) if 'Item' not in item: return False - except ClientError: + except ClientError: # pragma nocover pass result = self.table.delete_item(Key={'uaid': hasher(uaid)}) return result['ResponseMetadata']['HTTPStatusCode'] == 200 diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index 920565dc..9eb1e186 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -1,5 +1,6 @@ import unittest import uuid +import time from datetime import datetime, timedelta from autopush.websocket import ms_time @@ -224,6 +225,15 @@ def test_all_channels(self): assert chid2 not in chans assert chid in chans + def test_all_channels_expiry(self): + chid = str(uuid.uuid4()) + m = get_rotating_message_table() + message = Message(m, SinkMetrics()) + message.register_channel(self.uaid, chid, -100) + + _, chans = message.all_channels(self.uaid) + assert chid not in chans + def test_all_channels_fail(self): m = get_rotating_message_table() message = Message(m, SinkMetrics()) @@ -274,6 +284,24 @@ def test_message_storage(self): uuid.UUID(self.uaid), " ") assert len(all_messages) == 3 + def test_message_storage_expiry(self): + chid = str(uuid.uuid4()) + chidx = str(uuid.uuid4()) + m = get_rotating_message_table() + message = Message(m, SinkMetrics()) + message.register_channel(self.uaid, chid) + + expired = make_webpush_notification(self.uaid, chidx) + expired.ttl = -100 + message.store_message(make_webpush_notification(self.uaid, chid)) + message.store_message(expired) + + _, all_messages = message.fetch_timestamp_messages( + uuid.UUID(self.uaid), " ") + assert len(all_messages) == 1 + for x in all_messages: + assert x.channel_id != expired.channel_id + def test_message_storage_overwrite(self): """Test that store_message can overwrite existing messages which can occur in some reconnect cases but shouldn't error""" @@ -374,6 +402,23 @@ def test_no_uaid_found(self): with pytest.raises(ItemNotFound): router.get_uaid(uaid) + def test_uaid_expiry(self): + uaid = uuid.uuid4() + r = get_router_table() + router = Router(r, SinkMetrics()) + router.table.get_item = Mock() + router.table.get_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 200 + }, + "Item": { + "uaid": uaid.hex, + "expiry": int(time.time()) - 100 + } + } + with pytest.raises(ItemNotFound): + router.get_uaid(str(uaid)) + def test_uaid_provision_failed(self): r = get_router_table() router = Router(r, SinkMetrics()) diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index fdf38cc9..af7426aa 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -121,13 +121,6 @@ def test_delete_db_error(self): tok = ":".join(["m", dummy_uaid.hex, str(dummy_chid)]) self.fernet_mock.decrypt.return_value = tok - def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( - {'Error': {'Code': 'ConditionalCheckFailedException'}}, - 'mock_update_item' - ) - self.message_mock.configure_mock(**{"delete_message.return_value": False}) resp = yield self.client.delete(self.url(message_id="ignored")) diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index 98b5b667..a2a365ea 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -1005,7 +1005,7 @@ def test_ttl_batch_expired_and_good_one(self): data2 = str(uuid.uuid4()) client = yield self.quick_register() yield client.disconnect() - for x in range(0, 12): + for x in range(0, 7): yield client.send_notification(data=data, ttl=1) yield client.send_notification(data=data2)