Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: enforce AWS expiry on data lookups
Browse files Browse the repository at this point in the history
Closes #1051
  • Loading branch information
jrconlin committed Nov 7, 2017
1 parent 3283a6e commit a7d54bc
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 2 deletions.
15 changes: 14 additions & 1 deletion autopush/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ def all_channels(self, uaid):
# Note: This only returns the chids associated with the UAID.
# Functions that call store_message() would be required to
# update that list as well using register_channel()
# TODO: once expiry is properly integrated, use FilterExpression
result = self.table.get_item(
Key={
'uaid': hasher(uaid),
Expand All @@ -534,6 +535,9 @@ def all_channels(self, uaid):
return False, set([])
if 'Item' not in result:
return False, set([])
now = int(time.time())
if result['Item'].get('expiry', now) < now:
return False, set([])
return True, result['Item'].get("chids", set([]))

@track_provisioned
Expand Down Expand Up @@ -609,6 +613,7 @@ def fetch_messages(
"""
# Eagerly fetches all results in the result set.
# TODO: once expiry is properly integrated, use FilterExpression
response = self.table.query(
KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex))
& Key('chidmessageid').lt('02')),
Expand Down Expand Up @@ -654,16 +659,20 @@ def fetch_timestamp_messages(
else:
sortkey = "01;"

# TODO: once expiry is properly integrated, use FilterExpression
response = self.table.query(
KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex))
& Key('chidmessageid').gt(sortkey)),
ConsistentRead=True,
Limit=limit
)
now = int(time.time())
notifs = [
WebPushNotification.from_message_table(uaid, x) for x in
response.get("Items")
response.get("Items") if x.get('expiry', now) >= now
]
if response.get("Items") and not notifs:
print "here"
ts_notifs = [x for x in notifs if x.sortkey_timestamp]
last_position = None
if ts_notifs:
Expand Down Expand Up @@ -717,6 +726,7 @@ def get_uaid(self, uaid):
"""
try:
# TODO: once expiry is properly integrated, use FilterExpression
item = self.table.get_item(
Key={
'uaid': hasher(uaid)
Expand All @@ -729,6 +739,9 @@ def get_uaid(self, uaid):
item = item.get('Item')
if item is None:
raise ItemNotFound("uaid not found")
now = int(time.time())
if item.get('expiry', now) < now:
raise ItemNotFound("uaid not found")
if item.keys() == ['uaid']:
# Incomplete record, drop it.
self.drop_user(uaid)
Expand Down
45 changes: 45 additions & 0 deletions autopush/tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
import uuid
import time
from datetime import datetime, timedelta

from autopush.websocket import ms_time
Expand Down Expand Up @@ -224,6 +225,15 @@ def test_all_channels(self):
assert chid2 not in chans
assert chid in chans

def test_all_channels_expiry(self):
chid = str(uuid.uuid4())
m = get_rotating_message_table()
message = Message(m, SinkMetrics())
message.register_channel(self.uaid, chid, -100)

_, chans = message.all_channels(self.uaid)
assert chid not in chans

def test_all_channels_fail(self):
m = get_rotating_message_table()
message = Message(m, SinkMetrics())
Expand Down Expand Up @@ -274,6 +284,24 @@ def test_message_storage(self):
uuid.UUID(self.uaid), " ")
assert len(all_messages) == 3

def test_message_storage_expiry(self):
chid = str(uuid.uuid4())
chidx = str(uuid.uuid4())
m = get_rotating_message_table()
message = Message(m, SinkMetrics())
message.register_channel(self.uaid, chid)

expired = make_webpush_notification(self.uaid, chidx)
expired.ttl = -100
message.store_message(make_webpush_notification(self.uaid, chid))
message.store_message(expired)

_, all_messages = message.fetch_timestamp_messages(
uuid.UUID(self.uaid), " ")
assert len(all_messages) == 1
for x in all_messages:
assert x.channel_id != expired.channel_id

def test_message_storage_overwrite(self):
"""Test that store_message can overwrite existing messages which
can occur in some reconnect cases but shouldn't error"""
Expand Down Expand Up @@ -374,6 +402,23 @@ def test_no_uaid_found(self):
with pytest.raises(ItemNotFound):
router.get_uaid(uaid)

def test_uaid_expiry(self):
uaid = uuid.uuid4()
r = get_router_table()
router = Router(r, SinkMetrics())
router.table.get_item = Mock()
router.table.get_item.return_value = {
"ResponseMetadata": {
"HTTPStatusCode": 200
},
"Item": {
"uaid": uaid.hex,
"expiry": int(time.time()) - 100
}
}
with pytest.raises(ItemNotFound):
router.get_uaid(str(uaid))

def test_uaid_provision_failed(self):
r = get_router_table()
router = Router(r, SinkMetrics())
Expand Down
2 changes: 1 addition & 1 deletion autopush/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def test_ttl_batch_expired_and_good_one(self):
data2 = str(uuid.uuid4())
client = yield self.quick_register()
yield client.disconnect()
for x in range(0, 12):
for x in range(0, 7):
yield client.send_notification(data=data, ttl=1)

yield client.send_notification(data=data2)
Expand Down

0 comments on commit a7d54bc

Please sign in to comment.