Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
bug: limit valid months to acceptable range
Browse files Browse the repository at this point in the history
A user that tries to connect from a period longer than we currently
allow for could cause a "KeyError" on the server. Instead, we should
require that the user use a new UAID, which shoud cause the client to
re-register older connections.

Closes #350
  • Loading branch information
jrconlin committed Mar 23, 2016
1 parent 9fdd5ce commit a06c5ad
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 15 deletions.
9 changes: 5 additions & 4 deletions autopush/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ def normalize_id(id):
return '-'.join((raw[:8], raw[8:12], raw[12:16], raw[16:20], raw[20:]))


def make_rotating_tablename(prefix, delta=0):
def make_rotating_tablename(prefix, delta=0, date=None):
"""Creates a tablename for table rotation based on a prefix with a given
month delta."""
date = get_month(delta=delta)
if not date:
date = get_month(delta=delta)
return "{}_{}_{}".format(prefix, date.year, date.month)


Expand All @@ -77,11 +78,11 @@ def create_rotating_message_table(prefix="message", read_throughput=5,
)


def get_rotating_message_table(prefix="message", delta=0):
def get_rotating_message_table(prefix="message", delta=0, date=None):
"""Gets the message table for the current month."""
db = DynamoDBConnection()
dblist = db.list_tables()["TableNames"]
tablename = make_rotating_tablename(prefix, delta)
tablename = make_rotating_tablename(prefix, delta, date)
if tablename not in dblist:
return create_rotating_message_table(prefix=prefix, delta=delta)
else:
Expand Down
38 changes: 30 additions & 8 deletions autopush/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
get_router_table,
get_storage_table,
get_rotating_message_table,
make_rotating_tablename,
preflight_check,
Storage,
Router,
Message
Message,
)
from autopush.exceptions import InvalidTokenException
from autopush.metrics import (
Expand Down Expand Up @@ -163,9 +162,13 @@ def __init__(self,
self.router = Router(self.router_table, self.metrics)

# Used to determine whether a connection is out of date with current
# db objects
self.current_msg_month = make_rotating_tablename(self._message_prefix)
self.current_month = datetime.date.today().month
# db objects. There are three noteworty cases:
# 1 "Last Month" the table requires a rollover.
# 2 "This Month" the most common case.
# 3 "Next Month" where the system will soon be rolling over, but with
# timing, some nodes may roll over sooner. Ensuring the next month's
# table is present before the switchover is the main reason for this,
# just in case some nodes do switch sooner.
self.create_initial_message_tables()

# Run preflight check
Expand Down Expand Up @@ -204,18 +207,29 @@ def message(self, value):
"""Setter to set the current message table"""
self.message_tables[self.current_msg_month] = value

def _tomorrow(self):
return datetime.date.today() + datetime.timedelta(days=1)

def create_initial_message_tables(self):
"""Initializes a dict of the initial rotating messages tables.
An entry for last months table, and an entry for this months table.
An entry for last months table, an entry for this months table,
an entry for tomorrow, if tomorrow is a new month.
"""
today = datetime.date.today()
last_month = get_rotating_message_table(self._message_prefix, -1)
this_month = get_rotating_message_table(self._message_prefix)
self.current_month = today.month
self.current_msg_month = this_month.table_name
self.message_tables = {
last_month.table_name: Message(last_month, self.metrics),
this_month.table_name: Message(this_month, self.metrics),
this_month.table_name: Message(this_month, self.metrics)
}
if self._tomorrow().month != today.month:
next_month = get_rotating_message_table(delta=1)
self.message_tables[next_month.table_name] = Message(
next_month, self.metrics)

@inlineCallbacks
def update_rotating_tables(self):
Expand All @@ -227,6 +241,15 @@ def update_rotating_tables(self):
"""
today = datetime.date.today()
tomorrow = self._tomorrow()
if ((tomorrow.month != today.month) and
sorted(self.message_tables.keys())[-1] !=
tomorrow.month):
next_month = get_rotating_message_table(
self._message_prefix, 0, tomorrow)
self.message_tables[next_month.table_name] = Message(
next_month, self.metrics)

if today.month == self.current_month:
# No change in month, we're fine.
returnValue(False)
Expand All @@ -241,7 +264,6 @@ def update_rotating_tables(self):
self.current_msg_month = message_table.table_name
self.message_tables[self.current_msg_month] = \
Message(message_table, self.metrics)

returnValue(True)

def update(self, **kwargs):
Expand Down
8 changes: 8 additions & 0 deletions autopush/tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def test_hasher(self):
'd8f614d06cdd592cb8470f31177c8331a')
db.key_hash = ""

def test_normalize_id(self):
import autopush.db as db
abnormal = "deadbeef00000000decafbad00000000"
normal = "deadbeef-0000-0000-deca-fbad00000000"
eq_(db.normalize_id(abnormal), normal)
self.assertRaises(ValueError, db.normalize_id, "invalid")
eq_(db.normalize_id(abnormal.upper()), normal)


class StorageTestCase(unittest.TestCase):
def setUp(self):
Expand Down
45 changes: 45 additions & 0 deletions autopush/tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import datetime

from mock import Mock, patch
from moto import mock_dynamodb2, mock_s3
Expand Down Expand Up @@ -44,6 +45,21 @@ def test_resolve_host_no_interface(self, mock_socket):
ip = resolve_ip("example.com")
eq_(ip, "example.com")

def test_new_month(self):
today = datetime.date.today()
next_month = today.month + 1
next_year = today.year
if next_month > 12: # pragma: nocover
next_month = 1
next_year += 1
tomorrow = datetime.datetime(year=next_year,
month=next_month,
day=1)
AutopushSettings._tomorrow = Mock()
AutopushSettings._tomorrow.return_value = tomorrow
settings = AutopushSettings()
eq_(len(settings.message_tables), 3)


class SettingsAsyncTestCase(trialtest.TestCase):
def test_update_rotating_tables(self):
Expand All @@ -65,6 +81,35 @@ def check_tables(result):
d.addCallback(check_tables)
return d

def test_update_rotating_tables_month_end(self):
today = datetime.date.today()
next_month = today.month + 1
next_year = today.year
if next_month > 12: # pragma: nocover
next_month = 1
next_year += 1
tomorrow = datetime.datetime(year=next_year,
month=next_month,
day=1)
AutopushSettings._tomorrow = Mock()
AutopushSettings._tomorrow.return_value = tomorrow
settings = AutopushSettings(
hostname="example.com", resolve_hostname=True)
# shift off tomorrow's table.

tomorrow_table = sorted(settings.message_tables.keys())[-1]
settings.message_tables.pop(tomorrow_table)

# Get the deferred back
d = settings.update_rotating_tables()

def check_tables(result):
eq_(len(settings.message_tables), 3)
eq_(sorted(settings.message_tables.keys())[-1], tomorrow_table)

d.addCallback(check_tables)
return d

def test_update_not_needed(self):
settings = AutopushSettings(
hostname="google.com", resolve_hostname=True)
Expand Down
132 changes: 131 additions & 1 deletion autopush/tests/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import datetime
import time
import uuid
from hashlib import sha256
Expand All @@ -18,7 +19,9 @@
from twisted.trial import unittest

import autopush.db as db
from autopush.db import create_rotating_message_table
from autopush.db import (
create_rotating_message_table,
)
from autopush.settings import AutopushSettings
from autopush.websocket import (
PushState,
Expand Down Expand Up @@ -406,6 +409,133 @@ def wait_for_agent_call(): # pragma: nocover
reactor.callLater(0.1, wait_for_agent_call)
return d

def test_hello_old(self):
orig_uaid = "deadbeef12345678decafbad12345678"
# router.register_user returns (registered, previous
target_day = datetime.date(2016, 2, 29)
msg_day = datetime.date(2015, 12, 15)
msg_date = "{}_{}_{}".format(
self.proto.ap_settings._message_prefix,
msg_day.year,
msg_day.month)
msg_data = {
"router_type": "webpush",
"node_id": "http://localhost",
"last_connect": int(msg_day.strftime("%s")),
"current_month": msg_date,
}

def fake_msg(data):
return (True, msg_data, data)

mock_msg = Mock(wraps=db.Message)
mock_msg.fetch_messages.return_value = []
self.proto.ap_settings.router.register_user = fake_msg
# massage message_tables to include our fake range
mt = self.proto.ps.settings.message_tables
for k in mt.keys():
del(mt[k])
mt['message_2016_1'] = mock_msg
mt['message_2016_2'] = mock_msg
mt['message_2016_3'] = mock_msg
with patch.object(datetime, 'date',
Mock(wraps=datetime.date)) as patched:
patched.today.return_value = target_day
self._connect()
self._send_message(dict(messageType="hello",
uaid=orig_uaid,
channelIDs=[],
use_webpush=True))

def check_result(msg):
eq_(self.proto.ps.rotate_message_table, False)
# it's fine you've not connected in a while, but
# you should recycle your endpoints since they're probably
# invalid by now anyway.
eq_(msg["status"], 200)
ok_(msg["uaid"] != orig_uaid)

return self._check_response(check_result)

def test_hello_tomorrow(self):
orig_uaid = "deadbeef12345678decafbad12345678"
# router.register_user returns (registered, previous
target_day = datetime.date(2016, 2, 29)
msg_day = datetime.date(2016, 3, 1)
msg_date = "{}_{}_{}".format(
self.proto.ap_settings._message_prefix,
msg_day.year,
msg_day.month)
msg_data = {
"router_type": "webpush",
"node_id": "http://localhost",
"last_connect": int(msg_day.strftime("%s")),
"current_month": msg_date,
}

def fake_msg(data):
return (True, msg_data, data)

mock_msg = Mock(wraps=db.Message)
mock_msg.fetch_messages.return_value = []
self.proto.ap_settings.router.register_user = fake_msg
# massage message_tables to include our fake range
mt = self.proto.ps.settings.message_tables
for k in mt.keys():
del(mt[k])
mt['message_2016_1'] = mock_msg
mt['message_2016_2'] = mock_msg
mt['message_2016_3'] = mock_msg
with patch.object(datetime, 'date',
Mock(wraps=datetime.date)) as patched:
patched.today.return_value = target_day
self._connect()
self._send_message(dict(messageType="hello",
uaid=orig_uaid,
channelIDs=[],
use_webpush=True))

def check_result(msg):
eq_(self.proto.ps.rotate_message_table, False)
# it's fine you've not connected in a while, but
# you should recycle your endpoints since they're probably
# invalid by now anyway.
eq_(msg["status"], 200)
eq_(msg["uaid"], orig_uaid)

return self._check_response(check_result)

"""
def test_add_tomorrow(self):
today = datetime.date(2016, 2, 29)
yester = datetime.date(2016, 1, 1)
tomorrow = datetime.date(2016, 3, 1)
today_table = "{}_{}_{}".format(
self.proto.ap_settings._message_prefix,
today.year,
today.month)
yester_table = "{}_{}_{}".format(
self.proto.ap_settings._message_prefix,
yester.year,
yester.month)
tomorrow_table = "{}_{}_{}".format(
self.proto.ap_settings._message_prefix,
tomorrow.year,
tomorrow.month)
mock_msg = Mock(wraps=db.Message)
mock_msg.fetch_messages.return_value = []
mt = self.proto.ps.settings.message_tables
for k in mt.keys():
del(mt[k])
mt[yester_table] = mock_msg
mt[today_table] = mock_msg
self._connect()
self.proto.ps.settings.add_tomorrow(today, today_table)
ok_(tomorrow_table in self.proto.ps.settings.message_tables)
"""

def test_hello(self):
self._connect()
self._send_message(dict(messageType="hello", channelIDs=[]))
Expand Down
9 changes: 7 additions & 2 deletions autopush/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,14 @@ def _check_message_table_rotation(self, previous):
self.transport.pauseProducing()
# Check for table rotation
cur_month = previous.get("current_month")
# Previous month user or new user, flag for message rotation and
# set the message_month to the router month
if cur_month != self.ps.message_month:
# Previous month user or new user, flag for message rotation and
# set the message_month to the router month
if cur_month not in self.ps.settings.message_tables:
# This UAID has expired. Force client to reregister.
self.ps.uaid = uuid.uuid4().hex
self._finish_webpush_hello()
return
self.ps.message_month = cur_month
self.ps.rotate_message_table = True

Expand Down

0 comments on commit a06c5ad

Please sign in to comment.