From f1eb2f2ebe45dd05e4dd3e21923863ec8541d247 Mon Sep 17 00:00:00 2001 From: jrconlin Date: Wed, 25 Oct 2017 12:26:39 -0700 Subject: [PATCH] feat: convert to use AWS boto3 Issue #1050 Closes #1049 --- .travis.yml | 2 +- autopush/__init__.py | 3 + autopush/db.py | 470 ++++++++++++++++++++--------- autopush/tests/__init__.py | 1 + autopush/tests/test_db.py | 224 ++++++++++---- autopush/tests/test_health.py | 46 +-- autopush/tests/test_integration.py | 26 +- autopush/tests/test_websocket.py | 20 +- autopush/utils.py | 25 +- autopush/web/health.py | 2 +- tox.ini | 1 + 11 files changed, 580 insertions(+), 240 deletions(-) diff --git a/.travis.yml b/.travis.yml index 5009aa14..4cc77fb6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,7 +20,7 @@ matrix: - env: TOXENV=py36-mypy WITH_RUST=false before_install: -# https://github.com/travis-ci/travis-ci/issues/7940 +#https://github.com/travis-ci/travis-ci/issues/7940 - sudo rm -f /etc/boto.cfg install: diff --git a/autopush/__init__.py b/autopush/__init__.py index 60a382fb..89b5086d 100644 --- a/autopush/__init__.py +++ b/autopush/__init__.py @@ -1 +1,4 @@ __version__ = '1.38.0' # pragma: nocover + +# Max DynamoDB record lifespan (~ 30 days) +MAX_EXPRY = 2592000 # pragma: nocover diff --git a/autopush/db.py b/autopush/db.py index cf3cb1d0..3b3c773c 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -43,16 +43,16 @@ attrib, Factory ) -from boto.exception import JSONResponseError + from boto.dynamodb2.exceptions import ( ConditionalCheckFailedException, ItemNotFound, ProvisionedThroughputExceededException, ) -from boto.dynamodb2.fields import HashKey, RangeKey, GlobalKeysOnlyIndex -from boto.dynamodb2.layer1 import DynamoDBConnection -from boto.dynamodb2.table import Table, Item -from boto.dynamodb2.types import NUMBER +import boto3 +from boto3.exceptions import Boto3Error +from botocore.exceptions import ClientError + from typing import ( # noqa TYPE_CHECKING, Any, @@ -72,6 +72,7 @@ from twisted.internet.threads import deferToThread import autopush.metrics +from autopush import MAX_EXPRY from autopush.exceptions import AutopushException from autopush.metrics import IMetrics # noqa from autopush.types import ItemLike # noqa @@ -92,6 +93,9 @@ TRACK_DB_CALLS = False DB_CALLS = [] +DYNAMODB = boto3.resource('dynamodb') +CLIENT = boto3.client('dynamodb') + def get_month(delta=0): # type: (int) -> datetime.date @@ -130,10 +134,7 @@ def dump_uaid(uaid_data): when dumped via repr. """ - if isinstance(uaid_data, Item): - return repr(uaid_data.items()) - else: - return repr(uaid_data) + return repr(uaid_data) def make_rotating_tablename(prefix, delta=0, date=None): @@ -151,12 +152,48 @@ def create_rotating_message_table(prefix="message", delta=0, date=None, # type: (str, int, Optional[datetime.date], int, int) -> Table """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta, date) - return Table.create(tablename, - schema=[HashKey("uaid"), - RangeKey("chidmessageid")], - throughput=dict(read=read_throughput, - write=write_throughput), - ) + + try: + table = DYNAMODB.Table(tablename) + if table.table_status == 'ACTIVE': + return table + except ClientError as ex: + if ex.response['Error']['Code'] != 'ResourceNotFoundException': + # If we hit this, our boto3 is misconfigured and we need to bail. 
+ raise ex # pragma nocover + table = DYNAMODB.create_table( + TableName=tablename, + KeySchema=[ + { + 'AttributeName': 'uaid', + 'KeyType': 'HASH' + }, + { + 'AttributeName': 'chidmessageid', + 'KeyType': 'RANGE' + }], + AttributeDefinitions=[{ + 'AttributeName': 'uaid', + 'AttributeType': 'S' + }, + { + 'AttributeName': 'chidmessageid', + 'AttributeType': 'S' + }], + ProvisionedThroughput={ + 'ReadCapacityUnits': read_throughput, + 'WriteCapacityUnits': write_throughput + }) + table.meta.client.get_waiter('table_exists').wait( + TableName=tablename) + table.meta.client.update_time_to_live( + TableName=tablename, + TimeToLiveSpecification={ + 'Enabled': True, + 'AttributeName': 'expry' + } + ) + return table def get_rotating_message_table(prefix="message", delta=0, date=None, @@ -165,14 +202,14 @@ def get_rotating_message_table(prefix="message", delta=0, date=None, # type: (str, int, Optional[datetime.date], int, int) -> Table """Gets the message table for the current month.""" tablename = make_rotating_tablename(prefix, delta, date) - if not table_exists(DynamoDBConnection(), tablename): + if not table_exists(tablename): return create_rotating_message_table( prefix=prefix, delta=delta, date=date, read_throughput=message_read_throughput, write_throughput=message_write_throughput, ) else: - return Table(tablename) + return DYNAMODB.Table(tablename) def create_router_table(tablename="router", read_throughput=5, @@ -191,27 +228,78 @@ def create_router_table(tablename="router", read_throughput=5, cost of additional queries during GC to locate expired users. """ - return Table.create(tablename, - schema=[HashKey("uaid")], - throughput=dict(read=read_throughput, - write=write_throughput), - global_indexes=[ - GlobalKeysOnlyIndex( - 'AccessIndex', - parts=[ - HashKey('last_connect', - data_type=NUMBER)], - throughput=dict(read=5, write=5))], - ) + + table = DYNAMODB.create_table( + TableName=tablename, + KeySchema=[ + { + 'AttributeName': 'uaid', + 'KeyType': 'HASH' + } + ], + AttributeDefinitions=[ + { + 'AttributeName': 'uaid', + 'AttributeType': 'S' + }, + { + 'AttributeName': 'last_connect', + 'AttributeType': 'N' + }], + ProvisionedThroughput={ + 'ReadCapacityUnits': read_throughput, + 'WriteCapacityUnits': write_throughput, + }, + GlobalSecondaryIndexes=[{ + 'IndexName': 'AccessIndex', + 'KeySchema': [ + { + 'AttributeName': "last_connect", + 'KeyType': "HASH" + } + ], + 'Projection': { + 'ProjectionType': 'INCLUDE', + 'NonKeyAttributes': [ + 'last_connect' + ], + }, + 'ProvisionedThroughput': { + 'ReadCapacityUnits': read_throughput, + 'WriteCapacityUnits': write_throughput, + } + }] + ) + table.meta.client.get_waiter('table_exists').wait( + TableName=tablename) + table.meta.client.update_time_to_live( + TableName=tablename, + TimeToLiveSpecification={ + 'Enabled': True, + 'AttributeName': 'expry' + } + ) + return table + + +def _drop_table(tablename): + try: + CLIENT.delete_table(TableName=tablename) + except ClientError: # pragma nocover + pass def _make_table(table_func, tablename, read_throughput, write_throughput): # type: (Callable[[str, int, int], Table], str, int, int) -> Table """Private common function to make a table with a table func""" - if not table_exists(DynamoDBConnection(), tablename): + if not table_exists(tablename): return table_func(tablename, read_throughput, write_throughput) else: - return Table(tablename) + return DYNAMODB.Table(tablename) + + +def _expry(ttl): + return int(time.time() + ttl) def get_router_table(tablename="router", read_throughput=5, @@ -238,8 
+326,8 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): # Verify tables are ready for use if they just got created ready = False while not ready: - tbl_status = [x.describe()["Table"]["TableStatus"] - for x in [message.table, router.table]] + tbl_status = [x.table_status() + for x in [message, router]] ready = all([status == "ACTIVE" for status in tbl_status]) if not ready: time.sleep(1) @@ -256,6 +344,7 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): update_id=message_id, message_id=message_id, ttl=60, + expry=_expry(60), ) # Store a notification, fetch it, delete it @@ -333,11 +422,14 @@ def generate_last_connect_values(date): yield int(val) -def list_tables(conn): +def list_tables(client=CLIENT): """Return a list of the names of all DynamoDB tables.""" start_table = None while True: - result = conn.list_tables(exclusive_start_table_name=start_table) + if start_table: + result = client.list_tables(ExclusiveStartTableName=start_table) + else: + result = client.list_tables() for table in result.get('TableNames', []): yield table start_table = result.get('LastEvaluatedTableName', None) @@ -345,14 +437,16 @@ def list_tables(conn): break -def table_exists(conn, tablename): +def table_exists(tablename, client=None): """Determine if the specified Table exists""" - return any(tablename == name for name in list_tables(conn)) + if not client: + client = CLIENT + return tablename in list_tables(client) class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics): + def __init__(self, table, metrics, max_ttl=MAX_EXPRY): # type: (Table, IMetrics) -> None """Create a new Message object @@ -363,23 +457,29 @@ def __init__(self, table, metrics): """ self.table = table self.metrics = metrics - self.encode = table._encode_keys + self._max_ttl = max_ttl + + def table_status(self): + return self.table.table_status @track_provisioned - def register_channel(self, uaid, channel_id): - # type: (str, str) -> bool + def register_channel(self, uaid, channel_id, ttl=None): + # type: (str, str, int) -> bool """Register a channel for a given uaid""" - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) # Generate our update expression - expr = "ADD chids :channel_id" - expr_values = self.encode({":channel_id": - set([normalize_id(channel_id)])}) - conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, + if ttl is None: + ttl = self._max_ttl + expr_values = { + ":channel_id": set([normalize_id(channel_id)]), + ":expry": _expry(ttl) + } + self.table.update_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + UpdateExpression='ADD chids :channel_id, expry :expry', + ExpressionAttributeValues=expr_values, ) return True @@ -387,23 +487,23 @@ def register_channel(self, uaid, channel_id): def unregister_channel(self, uaid, channel_id, **kwargs): # type: (str, str, **str) -> bool """Remove a channel registration for a given uaid""" - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) expr = "DELETE chids :channel_id" chid = normalize_id(channel_id) - expr_values = self.encode({":channel_id": set([chid])}) - - result = conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, - return_values="UPDATED_OLD", + expr_values = {":channel_id": set([chid])} + 
+ response = self.table.update_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, + ReturnValues="UPDATED_OLD", ) - chids = result.get('Attributes', {}).get('chids', {}) + chids = response.get('Attributes', {}).get('chids', {}) if chids: try: - return chid in self.table._dynamizer.decode(chids) + return chid in chids except (TypeError, AttributeError): # pragma: nocover pass # if, for some reason, there are no chids defined, return False. @@ -417,23 +517,31 @@ def all_channels(self, uaid): # Note: This only returns the chids associated with the UAID. # Functions that call store_message() would be required to # update that list as well using register_channel() - try: - result = self.table.get_item(consistent=True, - uaid=hasher(uaid), - chidmessageid=" ") - return (True, result["chids"] or set([])) - except ItemNotFound: + result = self.table.get_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + ConsistentRead=True + ) + if result['ResponseMetadata']['HTTPStatusCode'] != 200: + return False, set([]) + if 'Item' not in result: return False, set([]) + return True, result['Item'].get("chids", set([])) @track_provisioned def save_channels(self, uaid, channels): # type: (str, Set[str]) -> None """Save out a set of channels""" - self.table.put_item(data=dict( - uaid=hasher(uaid), - chidmessageid=" ", - chids=channels - ), overwrite=True) + self.table.put_item( + Item={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + 'chids': channels, + 'expry': _expry(self._max_ttl), + }, + ) @track_provisioned def store_message(self, notification): @@ -442,13 +550,17 @@ def store_message(self, notification): item = dict( uaid=hasher(notification.uaid.hex), chidmessageid=notification.sort_key, - data=notification.data, headers=notification.headers, ttl=notification.ttl, timestamp=notification.timestamp, - updateid=notification.update_id + updateid=notification.update_id, + expry=_expry(min( + notification.ttl or 0, + self._max_ttl)) ) - self.table.put_item(data=item, overwrite=True) + if notification.data: + item['data'] = notification.data + self.table.put_item(Item=item) @track_provisioned def delete_message(self, notification): @@ -457,16 +569,24 @@ def delete_message(self, notification): if notification.update_id: try: self.table.delete_item( - uaid=hasher(notification.uaid.hex), - chidmessageid=notification.sort_key, - expected={'updateid__eq': notification.update_id}) - except ConditionalCheckFailedException: + Key={ + 'uaid': hasher(notification.uaid.hex), + 'chidmessageid': notification.sort_key + }, + Expected={ + 'updateid': { + 'Exists': True, + 'Value': notification.update_id + } + }) + except ClientError: return False else: self.table.delete_item( - uaid=hasher(notification.uaid.hex), - chidmessageid=notification.sort_key, - ) + Key={ + 'uaid': hasher(notification.uaid.hex), + 'chidmessageid': notification.sort_key, + }) return True @track_provisioned @@ -483,10 +603,21 @@ def fetch_messages( """ # Eagerly fetches all results in the result set. 
- results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), - chidmessageid__lt="02", - consistent=True, limit=limit)) - + response = self.table.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [hasher(uaid.hex)], + 'ComparisonOperator': 'EQ' + }, + 'chidmessageid': { + 'AttributeValueList': ['02'], + 'ComparisonOperator': 'LT' + }, + }, + ConsistentRead=True, + Limit=limit, + ) + results = list(response['Items']) # First extract the position if applicable, slightly higher than 01: # to ensure we don't load any 01 remainders that didn't get deleted # yet @@ -525,11 +656,23 @@ def fetch_timestamp_messages( else: sortkey = "01;" - results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), - chidmessageid__gt=sortkey, - consistent=True, limit=limit)) + response = self.table.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [hasher(uaid.hex)], + 'ComparisonOperator': 'EQ', + }, + 'chidmessageid': { + 'AttributeValueList': [sortkey], + 'ComparisonOperator': 'GT', + }, + }, + ConsistentRead=True, + Limit=limit + ) notifs = [ - WebPushNotification.from_message_table(uaid, x) for x in results + WebPushNotification.from_message_table(uaid, x) for x in + response.get("Items") ] ts_notifs = [x for x in notifs if x.sortkey_timestamp] last_position = None @@ -541,22 +684,23 @@ def fetch_timestamp_messages( def update_last_message_read(self, uaid, timestamp): # type: (uuid.UUID, int) -> bool """Update the last read timestamp for a user""" - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid.hex), "chidmessageid": " "}) - expr = "SET current_timestamp=:timestamp" - expr_values = self.encode({":timestamp": timestamp}) - conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, + expr = "SET current_timestamp=:timestamp, expry=:expry" + expr_values = {":timestamp": timestamp, + ":expry": _expry(self._max_ttl)} + self.table.update_item( + Key={ + "uaid": hasher(uaid.hex), + "chidmessageid": " " + }, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, ) return True class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics): + def __init__(self, table, metrics, max_ttl=MAX_EXPRY): # type: (Table, IMetrics) -> None """Create a new Router object @@ -567,7 +711,10 @@ def __init__(self, table, metrics): """ self.table = table self.metrics = metrics - self.encode = table._encode_keys + self._max_ttl = max_ttl + + def table_status(self): + return self.table.table_status def get_uaid(self, uaid): # type: (str) -> Item @@ -580,7 +727,18 @@ def get_uaid(self, uaid): """ try: - item = self.table.get_item(consistent=True, uaid=hasher(uaid)) + item = self.table.get_item( + Key={ + 'uaid': hasher(uaid) + }, + ConsistentRead=True, + ) + + if item.get('ResponseMetadata').get('HTTPStatusCode') != 200: + raise ItemNotFound('uaid not found') + item = item.get('Item') + if item is None: + raise ItemNotFound("uaid not found") if item.keys() == ['uaid']: # Incomplete record, drop it. self.drop_user(uaid) @@ -591,7 +749,7 @@ def get_uaid(self, uaid): # will not see this, since JSONResponseError is a subclass and # will capture it raise - except JSONResponseError: # pragma: nocover + except Boto3Error: # pragma: nocover # We trap JSONResponseError because Moto returns text instead of # JSON when looking up values in empty tables. 
We re-throw the # correct ItemNotFound exception @@ -612,8 +770,7 @@ def register_user(self, data): """ # Fetch a senderid for this user - conn = self.table.connection - db_key = self.encode({"uaid": hasher(data["uaid"])}) + db_key = {"uaid": hasher(data["uaid"])} del data["uaid"] if "router_type" not in data or "connected_at" not in data: # Not specifying these values will generate an exception in AWS. @@ -621,7 +778,7 @@ def register_user(self, data): "or connected_at") # Generate our update expression expr = "SET " + ", ".join(["%s=:%s" % (x, x) for x in data.keys()]) - expr_values = self.encode({":%s" % k: v for k, v in data.items()}) + expr_values = {":%s" % k: v for k, v in data.items()} try: cond = """( attribute_not_exists(router_type) or @@ -630,14 +787,16 @@ def register_user(self, data): attribute_not_exists(node_id) or (connected_at < :connected_at) )""" - result = conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - condition_expression=cond, - expression_attribute_values=expr_values, - return_values="ALL_OLD", - ) + try: + result = self.table.update_item( + Key=db_key, + UpdateExpression=expr, + ConditionExpression=cond, + ExpressionAttributeValues=expr_values, + ReturnValues="ALL_OLD", + ) + except ClientError as ex: + raise ex if "Attributes" in result: r = {} for key, value in result["Attributes"].items(): @@ -649,6 +808,11 @@ def register_user(self, data): r[key] = value result = r return (True, result) + except ClientError as ex: + # ClientErrors are generated by a factory, and while they have a + # class, it's dynamically generated. + if 'ConditionalCheckFailedException' == ex.__class__.__name__: + return (False, {}) except ConditionalCheckFailedException: return (False, {}) @@ -658,16 +822,19 @@ def drop_user(self, uaid): """Drops a user record""" # The following hack ensures that only uaids that exist and are # deleted return true. 
- huaid = hasher(uaid) - return self.table.delete_item(uaid=huaid, - expected={"uaid__eq": huaid}) + try: + self.get_uaid(uaid) + except ItemNotFound: + return False + result = self.table.delete_item(Key={'uaid': hasher(uaid)}) + return result['ResponseMetadata']['HTTPStatusCode'] == 200 def delete_uaids(self, uaids): # type: (List[str]) -> None """Issue a batch delete call for the given uaids""" - with self.table.batch_write() as batch: + with self.table.batch_writer() as batch: for uaid in uaids: - batch.delete_item(uaid=uaid) + batch.delete_item(Key={'uaid': uaid}) def drop_old_users(self, months_ago=2): # type: (int) -> Iterable[int] @@ -699,10 +866,16 @@ def drop_old_users(self, months_ago=2): batched = [] for hash_key in generate_last_connect_values(prior_date): - result_set = self.table.query_2( - last_connect__eq=hash_key, - index="AccessIndex", + response = self.table.query( + KeyConditions={ + 'last_connect': { + 'AttributeValueList': [hash_key], + 'ComparisonOperator': 'EQ' + } + }, + IndexName="AccessIndex", ) + result_set = response.get('Items', []) for result in result_set: batched.append(result["uaid"]) @@ -716,6 +889,14 @@ def drop_old_users(self, months_ago=2): self.delete_uaids(batched) yield len(batched) + @track_provisioned + def _update_last_connect(self, uaid, last_connect): + self.table.update_item( + Key={"uaid": hasher(uaid)}, + UpdateExpression="SET last_connect=:last_connect", + ExpressionAttributeValues={":last_connect": last_connect} + ) + @track_provisioned def update_message_month(self, uaid, month): # type: (str, str) -> bool @@ -727,23 +908,23 @@ def update_message_month(self, uaid, month): timestamp. """ - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid)}) - expr = ("SET current_month=:curmonth, last_connect=:last_connect") - expr_values = self.encode({":curmonth": month, - ":last_connect": generate_last_connect(), - }) - conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, + db_key = {"uaid": hasher(uaid)} + expr = ("SET current_month=:curmonth, last_connect=:last_connect, " + "expry=:expry") + expr_values = {":curmonth": month, + ":last_connect": generate_last_connect(), + ":expry": _expry(self._max_ttl), + } + self.table.update_item( + Key=db_key, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, ) return True @track_provisioned def clear_node(self, item): - # type: (Item) -> bool + # type: (dict) -> bool """Given a router item and remove the node_id The node_id will only be cleared if the ``connected_at`` matches up @@ -755,23 +936,24 @@ def clear_node(self, item): exceeds throughput. """ - conn = self.table.connection # Pop out the node_id node_id = item["node_id"] del item["node_id"] try: cond = "(node_id = :node) and (connected_at = :conn)" - conn.put_item( - self.table.table_name, - item=self.encode(item), - condition_expression=cond, - expression_attribute_values=self.encode({ + self.table.put_item( + Item=item, + ConditionExpression=cond, + ExpressionAttributeValues={ ":node": node_id, ":conn": item["connected_at"], - }), + }, ) return True + except ClientError: + # UAID not found. 
+ return False except ConditionalCheckFailedException: return False @@ -789,6 +971,8 @@ class DatabaseManager(object): message_tables = attrib(default=Factory(dict)) # type: Dict[str, Message] current_msg_month = attrib(init=False) # type: Optional[str] current_month = attrib(init=False) # type: Optional[int] + # for testing: + client = attrib(default=CLIENT) # type: Optional[Any] def __attrs_post_init__(self): """Initialize sane defaults""" diff --git a/autopush/tests/__init__.py b/autopush/tests/__init__.py index 9940b8ee..9adf1027 100644 --- a/autopush/tests/__init__.py +++ b/autopush/tests/__init__.py @@ -29,6 +29,7 @@ def setUp(): # Setup the necessary message tables message_table = os.environ.get("MESSAGE_TABLE", "message_int_test") + create_rotating_message_table(prefix=message_table, delta=-1) create_rotating_message_table(prefix=message_table) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index cd6c71df..b98b8d81 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -1,5 +1,6 @@ import unittest import uuid +import time from datetime import datetime, timedelta from autopush.websocket import ms_time @@ -8,8 +9,7 @@ ProvisionedThroughputExceededException, ItemNotFound, ) -from boto.dynamodb2.layer1 import DynamoDBConnection -from boto.dynamodb2.items import Item +from botocore.exceptions import ClientError from mock import Mock import pytest @@ -22,7 +22,9 @@ Message, Router, generate_last_connect, - make_rotating_tablename) + make_rotating_tablename, + _drop_table, + _make_table) from autopush.exceptions import AutopushException from autopush.metrics import SinkMetrics from autopush.utils import WebPushNotification @@ -40,9 +42,19 @@ def make_webpush_notification(uaid, chid, ttl=100): update_id=message_id, message_id=message_id, ttl=ttl, + expry=time.time() + ttl, ) +class DbUtilsTest(unittest.TestCase): + def test_make_table(self): + fake_func = Mock() + fake_table = "DoesNotExist_{}".format(uuid.uuid4()) + + _make_table(fake_func, fake_table, 5, 10) + assert fake_func.call_args[0] == (fake_table, 5, 10) + + class DbCheckTestCase(unittest.TestCase): def test_preflight_check_fail(self): router = Router(get_router_table(), SinkMetrics()) @@ -73,17 +85,9 @@ def test_preflight_check_wait(self): router = Router(get_router_table(), SinkMetrics()) message = Message(get_rotating_message_table(), SinkMetrics()) - message.table.describe = mock_describe = Mock() + values = ["PENDING", "ACTIVE"] + message.table_status = Mock(side_effect=values) - values = [ - dict(Table=dict(TableStatus="PENDING")), - dict(Table=dict(TableStatus="ACTIVE")), - ] - - def return_vals(*args, **kwargs): - return values.pop(0) - - mock_describe.side_effect = return_vals pf_uaid = "deadbeef00000000deadbeef01010101" preflight_check(message, router, pf_uaid) # now check that the database reports no entries. 
@@ -128,11 +132,10 @@ class MessageTestCase(unittest.TestCase): def setUp(self): table = get_rotating_message_table() self.real_table = table - self.real_connection = table.connection self.uaid = str(uuid.uuid4()) def tearDown(self): - self.real_table.connection = self.real_connection + pass def test_register(self): chid = str(uuid.uuid4()) @@ -140,10 +143,21 @@ def test_register(self): message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) - # Verify its in the db - rows = m.query_2(uaid__eq=self.uaid, chidmessageid__eq=" ") - results = list(rows) - assert len(results) == 1 + # Verify it's in the db + response = m.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [self.uaid], + 'ComparisonOperator': 'EQ' + }, + 'chidmessageid': { + 'AttributeValueList': ['02'], + 'ComparisonOperator': 'LT' + } + }, + ConsistentRead=True, + ) + assert len(response.get('Items')) def test_unregister(self): chid = str(uuid.uuid4()) @@ -152,24 +166,48 @@ def test_unregister(self): message.register_channel(self.uaid, chid) # Verify its in the db - rows = m.query_2(uaid__eq=self.uaid, chidmessageid__eq=" ") - results = list(rows) + response = m.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [self.uaid], + 'ComparisonOperator': 'EQ' + }, + 'chidmessageid': { + 'AttributeValueList': [" "], + 'ComparisonOperator': 'EQ' + }, + }, + ConsistentRead=True, + ) + results = list(response.get('Items')) assert len(results) == 1 assert results[0]["chids"] == {chid} message.unregister_channel(self.uaid, chid) # Verify its not in the db - rows = m.query_2(uaid__eq=self.uaid, chidmessageid__eq=" ") - results = list(rows) + response = m.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [self.uaid], + 'ComparisonOperator': 'EQ' + }, + 'chidmessageid': { + 'AttributeValueList': [" "], + 'ComparisonOperator': 'EQ' + }, + }, + ConsistentRead=True, + ) + results = list(response.get('Items')) assert len(results) == 1 - assert results[0]["chids"] is None + assert results[0].get("chids") is None # Test for the very unlikely case that there's no 'chid' - m.connection.update_item = Mock() - m.connection.update_item.return_value = { - 'Attributes': {'uaid': {'S': self.uaid}}, - 'ConsumedCapacityUnits': 0.5} + m.update_item = Mock(return_value={ + 'Attributes': {'uaid': self.uaid}, + 'ResponseMetaData': {} + }) r = message.unregister_channel(self.uaid, dummy_chid) assert r is False @@ -190,6 +228,20 @@ def test_all_channels(self): assert chid2 not in chans assert chid in chans + def test_all_channels_fail(self): + m = get_rotating_message_table() + message = Message(m, SinkMetrics()) + + message.table.get_item = Mock() + message.table.get_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 400 + }, + } + + res = message.all_channels(self.uaid) + assert res == (False, set([])) + def test_save_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) @@ -254,7 +306,7 @@ def test_message_delete_fail_condition(self): message = Message(m, SinkMetrics()) def raise_condition(*args, **kwargs): - raise ConditionalCheckFailedException(None, None) + raise ClientError({}, 'delete_item') message.table = Mock() message.table.delete_item.side_effect = raise_condition @@ -268,6 +320,8 @@ def test_message_rotate_table_with_date(self): m = get_rotating_message_table(prefix=prefix, date=future) assert m.table_name == tbl_name + # Clean up the temp table. 
+ _drop_table(tbl_name) class RouterTestCase(unittest.TestCase): @@ -275,11 +329,11 @@ class RouterTestCase(unittest.TestCase): def setup_class(self): table = get_router_table() self.real_table = table - self.real_connection = table.connection + self.real_connection = table.meta.client @classmethod def teardown_class(self): - self.real_table.connection = self.real_connection + self.real_table.meta.client = self.real_connection def _create_minimal_record(self): data = { @@ -294,6 +348,8 @@ def test_drop_old_users(self): # First create a bunch of users r = get_router_table() router = Router(r, SinkMetrics()) + # Purge any existing users from previous runs. + router.drop_old_users(0) for _ in range(0, 53): router.register_user(self._create_minimal_record()) @@ -301,18 +357,19 @@ def test_drop_old_users(self): assert list(results) == [25, 25, 3] def test_custom_tablename(self): - db = DynamoDBConnection() db_name = "router_%s" % uuid.uuid4() - assert not table_exists(db, db_name) + assert not table_exists(db_name) create_router_table(db_name) - assert table_exists(db, db_name) + assert table_exists(db_name) + # Clean up the temp table. + _drop_table(db_name) def test_provisioning(self): db_name = "router_%s" % uuid.uuid4() r = create_router_table(db_name, 3, 17) - assert r.throughput["read"] == 3 - assert r.throughput["write"] == 17 + assert r.provisioned_throughput.get('ReadCapacityUnits') == 3 + assert r.provisioned_throughput.get('WriteCapacityUnits') == 17 def test_no_uaid_found(self): uaid = str(uuid.uuid4()) @@ -336,31 +393,58 @@ def raise_error(*args, **kwargs): def test_register_user_provision_failed(self): r = get_router_table() router = Router(r, SinkMetrics()) - router.table.connection = Mock() + router.table.meta.client = Mock() def raise_error(*args, **kwargs): raise ProvisionedThroughputExceededException(None, None) - router.table.connection.update_item.side_effect = raise_error + router.table.update_item = Mock(side_effect=raise_error) with pytest.raises(ProvisionedThroughputExceededException): router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, router_type="webpush")) + def test_register_user_condition_failed(self): + r = get_router_table() + router = Router(r, SinkMetrics()) + router.table.meta.client = Mock() + + def raise_error(*args, **kwargs): + raise ConditionalCheckFailedException(None, None) + + router.table.update_item = Mock(side_effect=raise_error) + res = router.register_user(dict(uaid=dummy_uaid, node_id="me", + connected_at=1234, + router_type="webpush")) + assert res == (False, {}) + def test_clear_node_provision_failed(self): r = get_router_table() router = Router(r, SinkMetrics()) - router.table.connection.put_item = Mock() def raise_error(*args, **kwargs): raise ProvisionedThroughputExceededException(None, None) - router.table.connection.put_item.side_effect = raise_error + router.table.put_item = Mock(side_effect=raise_error) with pytest.raises(ProvisionedThroughputExceededException): - router.clear_node(Item(r, dict(uaid=dummy_uaid, - connected_at="1234", - node_id="asdf", - router_type="webpush"))) + router.clear_node(dict(uaid=dummy_uaid, + connected_at="1234", + node_id="asdf", + router_type="webpush")) + + def test_clear_node_condition_failed(self): + r = get_router_table() + router = Router(r, SinkMetrics()) + + def raise_error(*args, **kwargs): + raise ConditionalCheckFailedException(None, None) + + router.table.put_item = Mock(side_effect=raise_error) + res = router.clear_node(dict(uaid=dummy_uaid, + connected_at="1234", + 
node_id="asdf", + router_type="webpush")) + assert res is False def test_incomplete_uaid(self): # Older records may be incomplete. We can't inject them using normal @@ -370,7 +454,14 @@ def test_incomplete_uaid(self): router = Router(r, SinkMetrics()) router.table.get_item = Mock() router.drop_user = Mock() - router.table.get_item.return_value = {"uaid": uuid.uuid4().hex} + router.table.get_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 200 + }, + "Item": { + "uaid": uuid.uuid4().hex + } + } try: router.register_user(dict(uaid=uaid)) except AutopushException: @@ -379,14 +470,28 @@ def test_incomplete_uaid(self): router.get_uaid(uaid) assert router.drop_user.called + def test_failed_uaid(self): + uaid = str(uuid.uuid4()) + r = get_router_table() + router = Router(r, SinkMetrics()) + router.table.get_item = Mock() + router.drop_user = Mock() + router.table.get_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 400 + }, + } + with pytest.raises(ItemNotFound): + router.get_uaid(uaid) + def test_save_new(self): r = get_router_table() router = Router(r, SinkMetrics()) # Sadly, moto currently does not return an empty value like boto # when not updating data. - router.table.connection = Mock() - router.table.connection.update_item.return_value = {} - result = router.register_user(dict(uaid="", node_id="me", + router.table.update_item = Mock(return_value={}) + result = router.register_user(dict(uaid=dummy_uaid, + node_id="me", router_type="webpush", connected_at=1234)) assert result[0] is True @@ -396,10 +501,14 @@ def test_save_fail(self): router = Router(r, SinkMetrics()) def raise_condition(*args, **kwargs): - raise ConditionalCheckFailedException(None, None) - - router.table.connection = Mock() - router.table.connection.update_item.side_effect = raise_condition + # client.exceptions.ConditionalCheckFailedException + import boto3 + raise boto3.client( + 'dynamodb').exceptions.ConditionalCheckFailedException( + {}, 'mock_update_item', + ) + + router.table.update_item = Mock(side_effect=raise_condition) router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, router_type="webpush") result = router.register_user(router_data) @@ -433,12 +542,15 @@ def test_node_clear_fail(self): router = Router(r, SinkMetrics()) def raise_condition(*args, **kwargs): - raise ConditionalCheckFailedException(None, None) + raise ClientError( + error_response={ + 'Code': 'ConditionalCheckFailedException', + }, + operation_name='update_item') - router.table.connection.put_item = Mock() - router.table.connection.put_item.side_effect = raise_condition + router.table.put_item = Mock(side_effect=raise_condition) data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) - result = router.clear_node(Item(r, data)) + result = router.clear_node(data) assert result is False def test_drop_user(self): diff --git a/autopush/tests/test_health.py b/autopush/tests/test_health.py index df47605c..097e7ad4 100644 --- a/autopush/tests/test_health.py +++ b/autopush/tests/test_health.py @@ -2,7 +2,7 @@ import twisted.internet.base from boto.dynamodb2.exceptions import InternalServerError -from mock import Mock +from mock import Mock, patch from twisted.internet.defer import inlineCallbacks from twisted.logger import globalLogPublisher from twisted.trial import unittest @@ -20,7 +20,7 @@ class HealthTestCase(unittest.TestCase): def setUp(self): - self.timeout = 0.5 + # self.timeout = 0.5 twisted.internet.base.DelayedCall.debug = True conf = AutopushConfig( @@ -51,33 +51,38 @@ def 
test_healthy(self): }) @inlineCallbacks - def test_aws_error(self): - from autopush.db import make_rotating_tablename + @patch('boto3.client') + def test_aws_error(self, mb): def raise_error(*args, **kwargs): raise InternalServerError(None, None) - self.router_table.connection.list_tables = Mock( - side_effect=raise_error) - table_name = make_rotating_tablename("message") - self.message.table.connection.list_tables = Mock( - return_value={"TableNames": [table_name]}) + + safe = self.client.app.db.client + self.client.app.db.client = Mock() + self.client.app.db.client.list_tables = Mock(side_effect=raise_error) yield self._assert_reply({ "status": "NOT OK", "version": __version__, "clients": 0, - "storage": {"status": "OK"}, + "storage": { + "status": "NOT OK", + "error": "Server error" + }, "router": { "status": "NOT OK", "error": "Server error" } }, InternalServerError) + self.client.app.db.client = safe + @inlineCallbacks def test_nonexistent_table(self): no_tables = Mock(return_value={"TableNames": []}) - self.message.table.connection.list_tables = no_tables - self.router_table.connection.list_tables = no_tables + safe = self.client.app.db.client + self.client.app.db.client = Mock() + self.client.app.db.client.list_tables = no_tables yield self._assert_reply({ "status": "NOT OK", @@ -92,15 +97,18 @@ def test_nonexistent_table(self): "error": "Nonexistent table" } }, MissingTableException) + self.client.app.db.client = safe @inlineCallbacks def test_internal_error(self): def raise_error(*args, **kwargs): raise Exception("synergies not aligned") - self.router_table.connection.list_tables = Mock( - return_value={"TableNames": ["router"]}) - self.message.table.connection.list_tables = Mock( - side_effect=raise_error) + + safe = self.client.app.db.client + self.client.app.db.client = Mock() + self.client.app.db.client.list_tables = Mock( + side_effect=raise_error + ) yield self._assert_reply({ "status": "NOT OK", @@ -110,8 +118,12 @@ def raise_error(*args, **kwargs): "status": "NOT OK", "error": "Internal error" }, - "router": {"status": "OK"} + "router": { + "status": "NOT OK", + "error": "Internal error" + } }, Exception) + self.client.app.db.client = safe @inlineCallbacks def _assert_reply(self, reply, exception=None): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index b25fa39e..e9f11d58 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -1167,9 +1167,15 @@ def test_webpush_monthly_rotation(self): # Verify last_connect is current, then move that back assert has_connected_this_month(c) today = get_month(delta=-1) - c["last_connect"] = int("%s%s020001" % (today.year, - str(today.month).zfill(2))) - yield deferToThread(c.partial_save) + last_connect = int("%s%s020001" % (today.year, + str(today.month).zfill(2))) + + yield deferToThread( + self.conn.db.router._update_last_connect, + client.uaid, + last_connect, + ) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month @@ -1187,9 +1193,8 @@ def test_webpush_monthly_rotation(self): # Remove the channels entry entirely from this month yield deferToThread( self.conn.db.message.table.delete_item, - uaid=client.uaid, - chidmessageid=" " - ) + Key={'uaid': client.uaid, 'chidmessageid': ' '} + ) # Verify the channel is gone exists, chans = yield deferToThread( @@ -1278,9 +1283,12 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Verify last_connect is 
current, then move that back assert has_connected_this_month(c) today = get_month(delta=-1) - c["last_connect"] = int("%s%s020001" % (today.year, - str(today.month).zfill(2))) - yield deferToThread(c.partial_save) + yield deferToThread( + self.conn.db.router._update_last_connect, + client.uaid, + int("%s%s020001" % (today.year, str(today.month).zfill(2))) + ) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 85681e35..67b87bf1 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -454,12 +454,14 @@ def test_hello_old(self): "current_month": msg_date, } router = self.proto.db.router - router.table.put_item(data=dict( - uaid=orig_uaid, - connected_at=ms_time(), - current_month=msg_date, - router_type="webpush" - )) + router.table.put_item( + Item=dict( + uaid=orig_uaid, + connected_at=ms_time(), + current_month=msg_date, + router_type="webpush" + ) + ) def fake_msg(data): return (True, msg_data) @@ -708,7 +710,7 @@ def test_hello_failure(self): self._connect() # Fail out the register_user call router = self.proto.db.router - router.table.connection.update_item = Mock(side_effect=KeyError) + router.table.update_item = Mock(side_effect=KeyError) self._send_message(dict(messageType="hello", channelIDs=[], use_webpush=True, stop=1)) @@ -727,7 +729,7 @@ def throw_error(*args, **kwargs): raise ProvisionedThroughputExceededException(None, None) router = self.proto.db.router - router.table.connection.update_item = Mock(side_effect=throw_error) + router.table.update_item = Mock(side_effect=throw_error) self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) @@ -747,7 +749,7 @@ def throw_error(*args, **kwargs): raise JSONResponseError(None, None) router = self.proto.db.router - router.table.connection.update_item = Mock(side_effect=throw_error) + router.table.update_item = Mock(side_effect=throw_error) self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) diff --git a/autopush/utils.py b/autopush/utils.py index 66c35ff7..e42322c1 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -25,6 +25,7 @@ import ecdsa from jose import jwt +from autopush import MAX_EXPRY from autopush.exceptions import (InvalidTokenException, VapidAuthException) from autopush.jwt import repad, VerifyJWT from autopush.types import ItemLike # noqa @@ -301,6 +302,8 @@ class WebPushNotification(object): # Whether this notification should follow legacy non-topic rules legacy = attrib(default=False) # type: bool + expry = attrib(default=MAX_EXPRY) # type: int + def generate_message_id(self, fernet): # type: (Fernet) -> str """Generate a message-id suitable for accessing the message @@ -335,6 +338,12 @@ def generate_message_id(self, fernet): self.update_id = self.message_id return self.message_id + @classmethod + def calc_expry(cls, ttl=0): + if not ttl: + return 0 + return int(time.time() + int(ttl)) + @staticmethod def parse_decrypted_message_id(decrypted_token): # type: (str) -> Dict[str, Any] @@ -475,13 +484,13 @@ def from_message_table(cls, uaid, item): key_info = cls.parse_sort_key(item["chidmessageid"]) if key_info["api_ver"] in ["01", "02"]: key_info["message_id"] = item["updateid"] - notif = cls( uaid=uaid, channel_id=uuid.UUID(key_info["channel_id"]), data=item.get("data"), headers=item.get("headers"), ttl=item["ttl"], + 
expry=cls.calc_expry(item['ttl']), topic=key_info.get("topic"), message_id=key_info["message_id"], update_id=item.get("updateid"), @@ -505,9 +514,14 @@ def from_webpush_request_schema(cls, data, fernet, legacy=False): """ sub = data["subscription"] notif = cls( - uaid=sub["uaid"], channel_id=sub["chid"], data=data["body"], - headers=data["headers"], ttl=data["headers"]["ttl"], - topic=data["headers"]["topic"], legacy=legacy, + uaid=sub["uaid"], + channel_id=sub["chid"], + data=data["body"], + headers=data["headers"], + ttl=data["headers"]["ttl"], + expry=cls.calc_expry(data["headers"]["ttl"]), + topic=data["headers"]["topic"], + legacy=legacy, ) if notif.data: @@ -538,6 +552,7 @@ def from_message_id(cls, message_id, fernet): channel_id=uuid.UUID(key_info["chid"]), data=None, ttl=None, + expry=0, topic=key_info["topic"], message_id=message_id, sortkey_timestamp=key_info.get("sortkey_timestamp"), @@ -554,6 +569,7 @@ def from_serialized(cls, uaid, data): data=data.get("data"), headers=data.get("headers"), ttl=data.get("ttl"), + expry=cls.calc_expry(data.get('ttl', 0)), topic=data.get("topic"), message_id=str(data["version"]), update_id=str(data["version"]), @@ -579,6 +595,7 @@ def serialize(self): channelID=normalize_id(self.channel_id), version=self.version, ttl=self.ttl, + expry=self.expry, topic=self.topic, timestamp=self.timestamp, ) diff --git a/autopush/web/health.py b/autopush/web/health.py index 2bd5fe4f..d28265b2 100644 --- a/autopush/web/health.py +++ b/autopush/web/health.py @@ -40,7 +40,7 @@ def get(self): def _check_table(self, table, name_over=None): """Checks the tables known about in DynamoDB""" - d = deferToThread(table_exists, table.connection, table.table_name) + d = deferToThread(table_exists, table.table_name, self.db.client) d.addCallback(self._check_success, name_over or table.table_name) d.addErrback(self._check_error, name_over or table.table_name) return d diff --git a/tox.ini b/tox.ini index 629eaf6c..e46a1caf 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ envlist = py27,pypy,flake8,py36-mypy deps = -rtest-requirements.txt usedevelop = True passenv = SKIP_INTEGRATION +setenv = AWS_DEFAULT_REGION=us-east-1 commands = nosetests {posargs} autopush install_command = pip install --pre {opts} {packages}
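
For reference, the shape of the conversion this patch applies throughout autopush/db.py: boto2's
keyword-style table calls become calls on a boto3 Table resource with explicit Key/Item dicts,
CamelCase parameters, and a response envelope that callers must unwrap. A minimal sketch using the
router get_uaid lookup shown above (illustrative only; hasher() is the existing helper in db.py):

    # boto2 (removed): keyword arguments, key encoding handled by the Table object
    item = table.get_item(consistent=True, uaid=hasher(uaid))

    # boto3 (added): explicit Key dict, CamelCase kwargs, dict response envelope
    response = table.get_item(
        Key={'uaid': hasher(uaid)},
        ConsistentRead=True,
    )
    item = response.get('Item')  # may be absent; get_uaid() raises ItemNotFound in that case

The same pattern covers the other call sites: query_2() becomes query(KeyConditions=...),
put_item(data=..., overwrite=True) becomes put_item(Item=...), and update_item() moves from the
low-level connection onto the Table resource, with every write now also setting the 'expry'
attribute (via _expry()) so DynamoDB's table-level TTL can reap stale rows.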