Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: enforce AWS expiry on data lookups
Browse files Browse the repository at this point in the history
Closes #1051
  • Loading branch information
jrconlin committed Nov 9, 2017
1 parent b57d9b5 commit 7209469
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 10 deletions.
15 changes: 13 additions & 2 deletions autopush/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def all_channels(self, uaid):
# Note: This only returns the chids associated with the UAID.
# Functions that call store_message() would be required to
# update that list as well using register_channel()
# TODO: once expiry is properly integrated, use FilterExpression
result = self.table.get_item(
Key={
'uaid': hasher(uaid),
Expand All @@ -534,6 +535,9 @@ def all_channels(self, uaid):
return False, set([])
if 'Item' not in result:
return False, set([])
now = int(time.time())
if result['Item'].get('expiry', now) < now:
return False, set([])
return True, result['Item'].get("chids", set([]))

@track_provisioned
Expand Down Expand Up @@ -609,6 +613,7 @@ def fetch_messages(
"""
# Eagerly fetches all results in the result set.
# TODO: once expiry is properly integrated, use FilterExpression
response = self.table.query(
KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex))
& Key('chidmessageid').lt('02')),
Expand Down Expand Up @@ -654,15 +659,17 @@ def fetch_timestamp_messages(
else:
sortkey = "01;"

# TODO: once expiry is properly integrated, use FilterExpression
response = self.table.query(
KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex))
& Key('chidmessageid').gt(sortkey)),
ConsistentRead=True,
Limit=limit
)
now = int(time.time())
notifs = [
WebPushNotification.from_message_table(uaid, x) for x in
response.get("Items")
response.get("Items") if x.get('expiry', now) >= now
]
ts_notifs = [x for x in notifs if x.sortkey_timestamp]
last_position = None
Expand Down Expand Up @@ -717,6 +724,7 @@ def get_uaid(self, uaid):
"""
try:
# TODO: once expiry is properly integrated, use FilterExpression
item = self.table.get_item(
Key={
'uaid': hasher(uaid)
Expand All @@ -729,6 +737,9 @@ def get_uaid(self, uaid):
item = item.get('Item')
if item is None:
raise ItemNotFound("uaid not found")
now = int(time.time())
if item.get('expiry', now) < now:
raise ItemNotFound("uaid not found")
if item.keys() == ['uaid']:
# Incomplete record, drop it.
self.drop_user(uaid)
Expand Down Expand Up @@ -813,7 +824,7 @@ def drop_user(self, uaid):
)
if 'Item' not in item:
return False
except ClientError:
except ClientError: # pragma nocover
pass
result = self.table.delete_item(Key={'uaid': hasher(uaid)})
return result['ResponseMetadata']['HTTPStatusCode'] == 200
Expand Down
45 changes: 45 additions & 0 deletions autopush/tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import uuid
import time
from datetime import datetime, timedelta

from autopush.websocket import ms_time
Expand Down Expand Up @@ -224,6 +225,15 @@ def test_all_channels(self):
assert chid2 not in chans
assert chid in chans

def test_all_channels_expiry(self):
chid = str(uuid.uuid4())
m = get_rotating_message_table()
message = Message(m, SinkMetrics())
message.register_channel(self.uaid, chid, -100)

_, chans = message.all_channels(self.uaid)
assert chid not in chans

def test_all_channels_fail(self):
m = get_rotating_message_table()
message = Message(m, SinkMetrics())
Expand Down Expand Up @@ -274,6 +284,24 @@ def test_message_storage(self):
uuid.UUID(self.uaid), " ")
assert len(all_messages) == 3

def test_message_storage_expiry(self):
chid = str(uuid.uuid4())
chidx = str(uuid.uuid4())
m = get_rotating_message_table()
message = Message(m, SinkMetrics())
message.register_channel(self.uaid, chid)

expired = make_webpush_notification(self.uaid, chidx)
expired.ttl = -100
message.store_message(make_webpush_notification(self.uaid, chid))
message.store_message(expired)

_, all_messages = message.fetch_timestamp_messages(
uuid.UUID(self.uaid), " ")
assert len(all_messages) == 1
for x in all_messages:
assert x.channel_id != expired.channel_id

def test_message_storage_overwrite(self):
"""Test that store_message can overwrite existing messages which
can occur in some reconnect cases but shouldn't error"""
Expand Down Expand Up @@ -374,6 +402,23 @@ def test_no_uaid_found(self):
with pytest.raises(ItemNotFound):
router.get_uaid(uaid)

def test_uaid_expiry(self):
uaid = uuid.uuid4()
r = get_router_table()
router = Router(r, SinkMetrics())
router.table.get_item = Mock()
router.table.get_item.return_value = {
"ResponseMetadata": {
"HTTPStatusCode": 200
},
"Item": {
"uaid": uaid.hex,
"expiry": int(time.time()) - 100
}
}
with pytest.raises(ItemNotFound):
router.get_uaid(str(uaid))

def test_uaid_provision_failed(self):
r = get_router_table()
router = Router(r, SinkMetrics())
Expand Down
7 changes: 0 additions & 7 deletions autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,6 @@ def test_delete_db_error(self):
tok = ":".join(["m", dummy_uaid.hex, str(dummy_chid)])
self.fernet_mock.decrypt.return_value = tok

def raise_condition(*args, **kwargs):
import autopush.db
raise autopush.db.g_client.exceptions.ClientError(
{'Error': {'Code': 'ConditionalCheckFailedException'}},
'mock_update_item'
)

self.message_mock.configure_mock(**{"delete_message.return_value":
False})
resp = yield self.client.delete(self.url(message_id="ignored"))
Expand Down
2 changes: 1 addition & 1 deletion autopush/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def test_ttl_batch_expired_and_good_one(self):
data2 = str(uuid.uuid4())
client = yield self.quick_register()
yield client.disconnect()
for x in range(0, 12):
for x in range(0, 7):
yield client.send_notification(data=data, ttl=1)

yield client.send_notification(data=data2)
Expand Down

0 comments on commit 7209469

Please sign in to comment.