From a06c5ad683af9a60438dedec55b23be6e80355ba Mon Sep 17 00:00:00 2001 From: jrconlin Date: Mon, 21 Mar 2016 13:05:16 -0700 Subject: [PATCH] bug: limit valid months to acceptable range A user that tries to connect from a period longer than we currently allow for could cause a "KeyError" on the server. Instead, we should require that the user use a new UAID, which shoud cause the client to re-register older connections. Closes #350 --- autopush/db.py | 9 ++- autopush/settings.py | 38 +++++++-- autopush/tests/test_db.py | 8 ++ autopush/tests/test_main.py | 45 +++++++++++ autopush/tests/test_websocket.py | 132 ++++++++++++++++++++++++++++++- autopush/websocket.py | 9 ++- 6 files changed, 226 insertions(+), 15 deletions(-) diff --git a/autopush/db.py b/autopush/db.py index ba8c2fc4..7b427063 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -58,10 +58,11 @@ def normalize_id(id): return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:])) -def make_rotating_tablename(prefix, delta=0): +def make_rotating_tablename(prefix, delta=0, date=None): """Creates a tablename for table rotation based on a prefix with a given month delta.""" - date = get_month(delta=delta) + if not date: + date = get_month(delta=delta) return "{}_{}_{}".format(prefix, date.year, date.month) @@ -77,11 +78,11 @@ def create_rotating_message_table(prefix="message", read_throughput=5, ) -def get_rotating_message_table(prefix="message", delta=0): +def get_rotating_message_table(prefix="message", delta=0, date=None): """Gets the message table for the current month.""" db = DynamoDBConnection() dblist = db.list_tables()["TableNames"] - tablename = make_rotating_tablename(prefix, delta) + tablename = make_rotating_tablename(prefix, delta, date) if tablename not in dblist: return create_rotating_message_table(prefix=prefix, delta=delta) else: diff --git a/autopush/settings.py b/autopush/settings.py index 669cfe7e..3478e78e 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -19,11 +19,10 @@ get_router_table, get_storage_table, get_rotating_message_table, - make_rotating_tablename, preflight_check, Storage, Router, - Message + Message, ) from autopush.exceptions import InvalidTokenException from autopush.metrics import ( @@ -163,9 +162,13 @@ def __init__(self, self.router = Router(self.router_table, self.metrics) # Used to determine whether a connection is out of date with current - # db objects - self.current_msg_month = make_rotating_tablename(self._message_prefix) - self.current_month = datetime.date.today().month + # db objects. There are three noteworty cases: + # 1 "Last Month" the table requires a rollover. + # 2 "This Month" the most common case. + # 3 "Next Month" where the system will soon be rolling over, but with + # timing, some nodes may roll over sooner. Ensuring the next month's + # table is present before the switchover is the main reason for this, + # just in case some nodes do switch sooner. self.create_initial_message_tables() # Run preflight check @@ -204,18 +207,29 @@ def message(self, value): """Setter to set the current message table""" self.message_tables[self.current_msg_month] = value + def _tomorrow(self): + return datetime.date.today() + datetime.timedelta(days=1) + def create_initial_message_tables(self): """Initializes a dict of the initial rotating messages tables. - An entry for last months table, and an entry for this months table. + An entry for last months table, an entry for this months table, + an entry for tomorrow, if tomorrow is a new month. """ + today = datetime.date.today() last_month = get_rotating_message_table(self._message_prefix, -1) this_month = get_rotating_message_table(self._message_prefix) + self.current_month = today.month + self.current_msg_month = this_month.table_name self.message_tables = { last_month.table_name: Message(last_month, self.metrics), - this_month.table_name: Message(this_month, self.metrics), + this_month.table_name: Message(this_month, self.metrics) } + if self._tomorrow().month != today.month: + next_month = get_rotating_message_table(delta=1) + self.message_tables[next_month.table_name] = Message( + next_month, self.metrics) @inlineCallbacks def update_rotating_tables(self): @@ -227,6 +241,15 @@ def update_rotating_tables(self): """ today = datetime.date.today() + tomorrow = self._tomorrow() + if ((tomorrow.month != today.month) and + sorted(self.message_tables.keys())[-1] != + tomorrow.month): + next_month = get_rotating_message_table( + self._message_prefix, 0, tomorrow) + self.message_tables[next_month.table_name] = Message( + next_month, self.metrics) + if today.month == self.current_month: # No change in month, we're fine. returnValue(False) @@ -241,7 +264,6 @@ def update_rotating_tables(self): self.current_msg_month = message_table.table_name self.message_tables[self.current_msg_month] = \ Message(message_table, self.metrics) - returnValue(True) def update(self, **kwargs): diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index f6e36c4b..45228d67 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -70,6 +70,14 @@ def test_hasher(self): 'd8f614d06cdd592cb8470f31177c8331a') db.key_hash = "" + def test_normalize_id(self): + import autopush.db as db + abnormal = "deadbeef00000000decafbad00000000" + normal = "deadbeef-0000-0000-deca-fbad00000000" + eq_(db.normalize_id(abnormal), normal) + self.assertRaises(ValueError, db.normalize_id, "invalid") + eq_(db.normalize_id(abnormal.upper()), normal) + class StorageTestCase(unittest.TestCase): def setUp(self): diff --git a/autopush/tests/test_main.py b/autopush/tests/test_main.py index 33b66f8b..89841d6e 100644 --- a/autopush/tests/test_main.py +++ b/autopush/tests/test_main.py @@ -1,4 +1,5 @@ import unittest +import datetime from mock import Mock, patch from moto import mock_dynamodb2, mock_s3 @@ -44,6 +45,21 @@ def test_resolve_host_no_interface(self, mock_socket): ip = resolve_ip("example.com") eq_(ip, "example.com") + def test_new_month(self): + today = datetime.date.today() + next_month = today.month + 1 + next_year = today.year + if next_month > 12: # pragma: nocover + next_month = 1 + next_year += 1 + tomorrow = datetime.datetime(year=next_year, + month=next_month, + day=1) + AutopushSettings._tomorrow = Mock() + AutopushSettings._tomorrow.return_value = tomorrow + settings = AutopushSettings() + eq_(len(settings.message_tables), 3) + class SettingsAsyncTestCase(trialtest.TestCase): def test_update_rotating_tables(self): @@ -65,6 +81,35 @@ def check_tables(result): d.addCallback(check_tables) return d + def test_update_rotating_tables_month_end(self): + today = datetime.date.today() + next_month = today.month + 1 + next_year = today.year + if next_month > 12: # pragma: nocover + next_month = 1 + next_year += 1 + tomorrow = datetime.datetime(year=next_year, + month=next_month, + day=1) + AutopushSettings._tomorrow = Mock() + AutopushSettings._tomorrow.return_value = tomorrow + settings = AutopushSettings( + hostname="example.com", resolve_hostname=True) + # shift off tomorrow's table. + + tomorrow_table = sorted(settings.message_tables.keys())[-1] + settings.message_tables.pop(tomorrow_table) + + # Get the deferred back + d = settings.update_rotating_tables() + + def check_tables(result): + eq_(len(settings.message_tables), 3) + eq_(sorted(settings.message_tables.keys())[-1], tomorrow_table) + + d.addCallback(check_tables) + return d + def test_update_not_needed(self): settings = AutopushSettings( hostname="google.com", resolve_hostname=True) diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 248e4839..dbe66760 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -1,4 +1,5 @@ import json +import datetime import time import uuid from hashlib import sha256 @@ -18,7 +19,9 @@ from twisted.trial import unittest import autopush.db as db -from autopush.db import create_rotating_message_table +from autopush.db import ( + create_rotating_message_table, +) from autopush.settings import AutopushSettings from autopush.websocket import ( PushState, @@ -406,6 +409,133 @@ def wait_for_agent_call(): # pragma: nocover reactor.callLater(0.1, wait_for_agent_call) return d + def test_hello_old(self): + orig_uaid = "deadbeef12345678decafbad12345678" + # router.register_user returns (registered, previous + target_day = datetime.date(2016, 2, 29) + msg_day = datetime.date(2015, 12, 15) + msg_date = "{}_{}_{}".format( + self.proto.ap_settings._message_prefix, + msg_day.year, + msg_day.month) + msg_data = { + "router_type": "webpush", + "node_id": "http://localhost", + "last_connect": int(msg_day.strftime("%s")), + "current_month": msg_date, + } + + def fake_msg(data): + return (True, msg_data, data) + + mock_msg = Mock(wraps=db.Message) + mock_msg.fetch_messages.return_value = [] + self.proto.ap_settings.router.register_user = fake_msg + # massage message_tables to include our fake range + mt = self.proto.ps.settings.message_tables + for k in mt.keys(): + del(mt[k]) + mt['message_2016_1'] = mock_msg + mt['message_2016_2'] = mock_msg + mt['message_2016_3'] = mock_msg + with patch.object(datetime, 'date', + Mock(wraps=datetime.date)) as patched: + patched.today.return_value = target_day + self._connect() + self._send_message(dict(messageType="hello", + uaid=orig_uaid, + channelIDs=[], + use_webpush=True)) + + def check_result(msg): + eq_(self.proto.ps.rotate_message_table, False) + # it's fine you've not connected in a while, but + # you should recycle your endpoints since they're probably + # invalid by now anyway. + eq_(msg["status"], 200) + ok_(msg["uaid"] != orig_uaid) + + return self._check_response(check_result) + + def test_hello_tomorrow(self): + orig_uaid = "deadbeef12345678decafbad12345678" + # router.register_user returns (registered, previous + target_day = datetime.date(2016, 2, 29) + msg_day = datetime.date(2016, 3, 1) + msg_date = "{}_{}_{}".format( + self.proto.ap_settings._message_prefix, + msg_day.year, + msg_day.month) + msg_data = { + "router_type": "webpush", + "node_id": "http://localhost", + "last_connect": int(msg_day.strftime("%s")), + "current_month": msg_date, + } + + def fake_msg(data): + return (True, msg_data, data) + + mock_msg = Mock(wraps=db.Message) + mock_msg.fetch_messages.return_value = [] + self.proto.ap_settings.router.register_user = fake_msg + # massage message_tables to include our fake range + mt = self.proto.ps.settings.message_tables + for k in mt.keys(): + del(mt[k]) + mt['message_2016_1'] = mock_msg + mt['message_2016_2'] = mock_msg + mt['message_2016_3'] = mock_msg + with patch.object(datetime, 'date', + Mock(wraps=datetime.date)) as patched: + patched.today.return_value = target_day + self._connect() + self._send_message(dict(messageType="hello", + uaid=orig_uaid, + channelIDs=[], + use_webpush=True)) + + def check_result(msg): + eq_(self.proto.ps.rotate_message_table, False) + # it's fine you've not connected in a while, but + # you should recycle your endpoints since they're probably + # invalid by now anyway. + eq_(msg["status"], 200) + eq_(msg["uaid"], orig_uaid) + + return self._check_response(check_result) + + """ + def test_add_tomorrow(self): + today = datetime.date(2016, 2, 29) + yester = datetime.date(2016, 1, 1) + tomorrow = datetime.date(2016, 3, 1) + today_table = "{}_{}_{}".format( + self.proto.ap_settings._message_prefix, + today.year, + today.month) + yester_table = "{}_{}_{}".format( + self.proto.ap_settings._message_prefix, + yester.year, + yester.month) + tomorrow_table = "{}_{}_{}".format( + self.proto.ap_settings._message_prefix, + tomorrow.year, + tomorrow.month) + + mock_msg = Mock(wraps=db.Message) + mock_msg.fetch_messages.return_value = [] + mt = self.proto.ps.settings.message_tables + for k in mt.keys(): + del(mt[k]) + mt[yester_table] = mock_msg + mt[today_table] = mock_msg + + self._connect() + self.proto.ps.settings.add_tomorrow(today, today_table) + ok_(tomorrow_table in self.proto.ps.settings.message_tables) + """ + def test_hello(self): self._connect() self._send_message(dict(messageType="hello", channelIDs=[])) diff --git a/autopush/websocket.py b/autopush/websocket.py index 62016d72..f00eac5d 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -776,9 +776,14 @@ def _check_message_table_rotation(self, previous): self.transport.pauseProducing() # Check for table rotation cur_month = previous.get("current_month") + # Previous month user or new user, flag for message rotation and + # set the message_month to the router month if cur_month != self.ps.message_month: - # Previous month user or new user, flag for message rotation and - # set the message_month to the router month + if cur_month not in self.ps.settings.message_tables: + # This UAID has expired. Force client to reregister. + self.ps.uaid = uuid.uuid4().hex + self._finish_webpush_hello() + return self.ps.message_month = cur_month self.ps.rotate_message_table = True