diff --git a/.travis.yml b/.travis.yml index 5009aa14..6a2e6457 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,6 +28,8 @@ install: - pip install tox ${CODECOV:+codecov} - if [ ${WITH_RUST:-true} != "false" ]; then curl https://sh.rustup.rs | sh -s -- -y || travis_terminate 1; fi - export PATH=$PATH:$HOME/.cargo/bin +- export AWS_SHARED_CREDENTIALS_FILE=./automock/credentials.cfg +- export BOTO_CONFIG=./automock/boto.cfg script: - tox -- ${CODECOV:+--with-coverage --cover-xml --cover-package=autopush} after_success: diff --git a/automock/boto.cfg b/automock/boto.cfg index be93ac94..318b7a7f 100644 --- a/automock/boto.cfg +++ b/automock/boto.cfg @@ -1,9 +1,11 @@ -[Credentials] -aws_access_key_id = -aws_secret_access_key = +[default] [Boto] is_secure = False https_validate_certificates = False proxy_port = 8000 proxy = 127.0.0.1 + +[DynamoDB] +region=us-east-1 +validate_checksums=False diff --git a/autopush/__init__.py b/autopush/__init__.py index 0c7a68a3..b893a1d7 100644 --- a/autopush/__init__.py +++ b/autopush/__init__.py @@ -1 +1,4 @@ __version__ = '1.39.0' # pragma: nocover + +# Max DynamoDB record lifespan (~ 30 days) +MAX_EXPIRY = 2592000 # pragma: nocover diff --git a/autopush/db.py b/autopush/db.py index cf3cb1d0..bae7eaea 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -32,6 +32,7 @@ from __future__ import absolute_import import datetime +import os import random import time import uuid @@ -43,16 +44,16 @@ attrib, Factory ) -from boto.exception import JSONResponseError + from boto.dynamodb2.exceptions import ( - ConditionalCheckFailedException, ItemNotFound, - ProvisionedThroughputExceededException, ) -from boto.dynamodb2.fields import HashKey, RangeKey, GlobalKeysOnlyIndex -from boto.dynamodb2.layer1 import DynamoDBConnection -from boto.dynamodb2.table import Table, Item -from boto.dynamodb2.types import NUMBER +import boto3 +import botocore +from boto3.dynamodb.conditions import Key +from boto3.exceptions import Boto3Error +from botocore.exceptions import ClientError + from typing import ( # noqa TYPE_CHECKING, Any, @@ -72,6 +73,7 @@ from twisted.internet.threads import deferToThread import autopush.metrics +from autopush import MAX_EXPIRY from autopush.exceptions import AutopushException from autopush.metrics import IMetrics # noqa from autopush.types import ItemLike # noqa @@ -92,6 +94,16 @@ TRACK_DB_CALLS = False DB_CALLS = [] +# See https://botocore.readthedocs.io/en/stable/reference/config.html for +# additional config options +g_dynamodb = boto3.resource( + 'dynamodb', + config=botocore.config.Config( + region_name=os.getenv("AWS_REGION_NAME", "us-east-1") + ) +) +g_client = g_dynamodb.meta.client + def get_month(delta=0): # type: (int) -> datetime.date @@ -121,21 +133,6 @@ def hasher(uaid): return uaid -def dump_uaid(uaid_data): - # type: (ItemLike) -> str - """Return a dict for a uaid. - - This is utilized instead of repr since some db methods return a - DynamoDB Item which does not actually show its dict key/values - when dumped via repr. - - """ - if isinstance(uaid_data, Item): - return repr(uaid_data.items()) - else: - return repr(uaid_data) - - def make_rotating_tablename(prefix, delta=0, date=None): # type: (str, int, Optional[datetime.date]) -> str """Creates a tablename for table rotation based on a prefix with a given @@ -151,12 +148,54 @@ def create_rotating_message_table(prefix="message", delta=0, date=None, # type: (str, int, Optional[datetime.date], int, int) -> Table """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta, date) - return Table.create(tablename, - schema=[HashKey("uaid"), - RangeKey("chidmessageid")], - throughput=dict(read=read_throughput, - write=write_throughput), - ) + + try: + table = g_dynamodb.Table(tablename) + if table.table_status == 'ACTIVE': # pragma nocover + return table + except ClientError as ex: + if ex.response['Error']['Code'] != 'ResourceNotFoundException': + # If we hit this, our boto3 is misconfigured and we need to bail. + raise ex # pragma nocover + table = g_dynamodb.create_table( + TableName=tablename, + KeySchema=[ + { + 'AttributeName': 'uaid', + 'KeyType': 'HASH' + }, + { + 'AttributeName': 'chidmessageid', + 'KeyType': 'RANGE' + }], + AttributeDefinitions=[ + { + 'AttributeName': 'uaid', + 'AttributeType': 'S' + }, + { + 'AttributeName': 'chidmessageid', + 'AttributeType': 'S' + }], + ProvisionedThroughput={ + 'ReadCapacityUnits': read_throughput, + 'WriteCapacityUnits': write_throughput + }) + table.meta.client.get_waiter('table_exists').wait( + TableName=tablename) + try: + table.meta.client.update_time_to_live( + TableName=tablename, + TimeToLiveSpecification={ + 'Enabled': True, + 'AttributeName': 'expiry' + } + ) + except ClientError as ex: # pragma nocover + if ex.response['Error']['Code'] != 'UnknownOperationException': + # DynamoDB local library does not yet support TTL + raise + return table def get_rotating_message_table(prefix="message", delta=0, date=None, @@ -165,14 +204,14 @@ def get_rotating_message_table(prefix="message", delta=0, date=None, # type: (str, int, Optional[datetime.date], int, int) -> Table """Gets the message table for the current month.""" tablename = make_rotating_tablename(prefix, delta, date) - if not table_exists(DynamoDBConnection(), tablename): + if not table_exists(tablename): return create_rotating_message_table( prefix=prefix, delta=delta, date=date, read_throughput=message_read_throughput, write_throughput=message_write_throughput, ) else: - return Table(tablename) + return g_dynamodb.Table(tablename) def create_router_table(tablename="router", read_throughput=5, @@ -191,27 +230,84 @@ def create_router_table(tablename="router", read_throughput=5, cost of additional queries during GC to locate expired users. """ - return Table.create(tablename, - schema=[HashKey("uaid")], - throughput=dict(read=read_throughput, - write=write_throughput), - global_indexes=[ - GlobalKeysOnlyIndex( - 'AccessIndex', - parts=[ - HashKey('last_connect', - data_type=NUMBER)], - throughput=dict(read=5, write=5))], - ) + + table = g_dynamodb.create_table( + TableName=tablename, + KeySchema=[ + { + 'AttributeName': 'uaid', + 'KeyType': 'HASH' + } + ], + AttributeDefinitions=[ + { + 'AttributeName': 'uaid', + 'AttributeType': 'S' + }, + { + 'AttributeName': 'last_connect', + 'AttributeType': 'N' + }], + ProvisionedThroughput={ + 'ReadCapacityUnits': read_throughput, + 'WriteCapacityUnits': write_throughput, + }, + GlobalSecondaryIndexes=[ + { + 'IndexName': 'AccessIndex', + 'KeySchema': [ + { + 'AttributeName': "last_connect", + 'KeyType': "HASH" + } + ], + 'Projection': { + 'ProjectionType': 'INCLUDE', + 'NonKeyAttributes': [ + 'last_connect' + ], + }, + 'ProvisionedThroughput': { + 'ReadCapacityUnits': read_throughput, + 'WriteCapacityUnits': write_throughput, + } + } + ] + ) + table.meta.client.get_waiter('table_exists').wait( + TableName=tablename) + try: + table.meta.client.update_time_to_live( + TableName=tablename, + TimeToLiveSpecification={ + 'Enabled': True, + 'AttributeName': 'expiry' + } + ) + except ClientError as ex: # pragma nocover + if ex.response["Error"]["Code"] != "UnknownOperationException": + raise + return table + + +def _drop_table(tablename): + try: + g_client.delete_table(TableName=tablename) + except ClientError: # pragma nocover + pass def _make_table(table_func, tablename, read_throughput, write_throughput): # type: (Callable[[str, int, int], Table], str, int, int) -> Table """Private common function to make a table with a table func""" - if not table_exists(DynamoDBConnection(), tablename): + if not table_exists(tablename): return table_func(tablename, read_throughput, write_throughput) else: - return Table(tablename) + return g_dynamodb.Table(tablename) + + +def _expiry(ttl): + return int(time.time() + ttl) def get_router_table(tablename="router", read_throughput=5, @@ -238,8 +334,7 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): # Verify tables are ready for use if they just got created ready = False while not ready: - tbl_status = [x.describe()["Table"]["TableStatus"] - for x in [message.table, router.table]] + tbl_status = [x.table_status() for x in [message, router]] ready = all([status == "ACTIVE" for status in tbl_status]) if not ready: time.sleep(1) @@ -333,11 +428,14 @@ def generate_last_connect_values(date): yield int(val) -def list_tables(conn): +def list_tables(client=g_client): """Return a list of the names of all DynamoDB tables.""" start_table = None while True: - result = conn.list_tables(exclusive_start_table_name=start_table) + if start_table: # pragma nocover + result = client.list_tables(ExclusiveStartTableName=start_table) + else: + result = client.list_tables() for table in result.get('TableNames', []): yield table start_table = result.get('LastEvaluatedTableName', None) @@ -345,14 +443,16 @@ def list_tables(conn): break -def table_exists(conn, tablename): +def table_exists(tablename, client=None): """Determine if the specified Table exists""" - return any(tablename == name for name in list_tables(conn)) + if not client: + client = g_client + return tablename in list_tables(client) class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics): + def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): # type: (Table, IMetrics) -> None """Create a new Message object @@ -363,23 +463,29 @@ def __init__(self, table, metrics): """ self.table = table self.metrics = metrics - self.encode = table._encode_keys + self._max_ttl = max_ttl + + def table_status(self): + return self.table.table_status @track_provisioned - def register_channel(self, uaid, channel_id): - # type: (str, str) -> bool + def register_channel(self, uaid, channel_id, ttl=None): + # type: (str, str, int) -> bool """Register a channel for a given uaid""" - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) # Generate our update expression - expr = "ADD chids :channel_id" - expr_values = self.encode({":channel_id": - set([normalize_id(channel_id)])}) - conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, + if ttl is None: + ttl = self._max_ttl + expr_values = { + ":channel_id": set([normalize_id(channel_id)]), + ":expiry": _expiry(ttl) + } + self.table.update_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + UpdateExpression='ADD chids :channel_id, expiry :expiry', + ExpressionAttributeValues=expr_values, ) return True @@ -387,23 +493,23 @@ def register_channel(self, uaid, channel_id): def unregister_channel(self, uaid, channel_id, **kwargs): # type: (str, str, **str) -> bool """Remove a channel registration for a given uaid""" - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid), "chidmessageid": " "}) expr = "DELETE chids :channel_id" chid = normalize_id(channel_id) - expr_values = self.encode({":channel_id": set([chid])}) - - result = conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, - return_values="UPDATED_OLD", + expr_values = {":channel_id": set([chid])} + + response = self.table.update_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, + ReturnValues="UPDATED_OLD", ) - chids = result.get('Attributes', {}).get('chids', {}) + chids = response.get('Attributes', {}).get('chids', {}) if chids: try: - return chid in self.table._dynamizer.decode(chids) + return chid in chids except (TypeError, AttributeError): # pragma: nocover pass # if, for some reason, there are no chids defined, return False. @@ -417,23 +523,31 @@ def all_channels(self, uaid): # Note: This only returns the chids associated with the UAID. # Functions that call store_message() would be required to # update that list as well using register_channel() - try: - result = self.table.get_item(consistent=True, - uaid=hasher(uaid), - chidmessageid=" ") - return (True, result["chids"] or set([])) - except ItemNotFound: + result = self.table.get_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + ConsistentRead=True + ) + if result['ResponseMetadata']['HTTPStatusCode'] != 200: + return False, set([]) + if 'Item' not in result: return False, set([]) + return True, result['Item'].get("chids", set([])) @track_provisioned def save_channels(self, uaid, channels): # type: (str, Set[str]) -> None """Save out a set of channels""" - self.table.put_item(data=dict( - uaid=hasher(uaid), - chidmessageid=" ", - chids=channels - ), overwrite=True) + self.table.put_item( + Item={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + 'chids': channels, + 'expiry': _expiry(self._max_ttl), + }, + ) @track_provisioned def store_message(self, notification): @@ -442,13 +556,17 @@ def store_message(self, notification): item = dict( uaid=hasher(notification.uaid.hex), chidmessageid=notification.sort_key, - data=notification.data, headers=notification.headers, ttl=notification.ttl, timestamp=notification.timestamp, - updateid=notification.update_id + updateid=notification.update_id, + expiry=_expiry(min( + notification.ttl or 0, + self._max_ttl)) ) - self.table.put_item(data=item, overwrite=True) + if notification.data: + item['data'] = notification.data + self.table.put_item(Item=item) @track_provisioned def delete_message(self, notification): @@ -457,16 +575,24 @@ def delete_message(self, notification): if notification.update_id: try: self.table.delete_item( - uaid=hasher(notification.uaid.hex), - chidmessageid=notification.sort_key, - expected={'updateid__eq': notification.update_id}) - except ConditionalCheckFailedException: + Key={ + 'uaid': hasher(notification.uaid.hex), + 'chidmessageid': notification.sort_key + }, + Expected={ + 'updateid': { + 'Exists': True, + 'Value': notification.update_id + } + }) + except ClientError: return False else: self.table.delete_item( - uaid=hasher(notification.uaid.hex), - chidmessageid=notification.sort_key, - ) + Key={ + 'uaid': hasher(notification.uaid.hex), + 'chidmessageid': notification.sort_key, + }) return True @track_provisioned @@ -483,10 +609,13 @@ def fetch_messages( """ # Eagerly fetches all results in the result set. - results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), - chidmessageid__lt="02", - consistent=True, limit=limit)) - + response = self.table.query( + KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex)) + & Key('chidmessageid').lt('02')), + ConsistentRead=True, + Limit=limit, + ) + results = list(response['Items']) # First extract the position if applicable, slightly higher than 01: # to ensure we don't load any 01 remainders that didn't get deleted # yet @@ -525,11 +654,15 @@ def fetch_timestamp_messages( else: sortkey = "01;" - results = list(self.table.query_2(uaid__eq=hasher(uaid.hex), - chidmessageid__gt=sortkey, - consistent=True, limit=limit)) + response = self.table.query( + KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex)) + & Key('chidmessageid').gt(sortkey)), + ConsistentRead=True, + Limit=limit + ) notifs = [ - WebPushNotification.from_message_table(uaid, x) for x in results + WebPushNotification.from_message_table(uaid, x) for x in + response.get("Items") ] ts_notifs = [x for x in notifs if x.sortkey_timestamp] last_position = None @@ -541,22 +674,23 @@ def fetch_timestamp_messages( def update_last_message_read(self, uaid, timestamp): # type: (uuid.UUID, int) -> bool """Update the last read timestamp for a user""" - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid.hex), "chidmessageid": " "}) - expr = "SET current_timestamp=:timestamp" - expr_values = self.encode({":timestamp": timestamp}) - conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, + expr = "SET current_timestamp=:timestamp, expiry=:expiry" + expr_values = {":timestamp": timestamp, + ":expiry": _expiry(self._max_ttl)} + self.table.update_item( + Key={ + "uaid": hasher(uaid.hex), + "chidmessageid": " " + }, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, ) return True class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics): + def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): # type: (Table, IMetrics) -> None """Create a new Router object @@ -567,7 +701,10 @@ def __init__(self, table, metrics): """ self.table = table self.metrics = metrics - self.encode = table._encode_keys + self._max_ttl = max_ttl + + def table_status(self): + return self.table.table_status def get_uaid(self, uaid): # type: (str) -> Item @@ -580,18 +717,24 @@ def get_uaid(self, uaid): """ try: - item = self.table.get_item(consistent=True, uaid=hasher(uaid)) + item = self.table.get_item( + Key={ + 'uaid': hasher(uaid) + }, + ConsistentRead=True, + ) + + if item.get('ResponseMetadata').get('HTTPStatusCode') != 200: + raise ItemNotFound('uaid not found') + item = item.get('Item') + if item is None: + raise ItemNotFound("uaid not found") if item.keys() == ['uaid']: # Incomplete record, drop it. self.drop_user(uaid) raise ItemNotFound("uaid not found") return item - except ProvisionedThroughputExceededException: - # We unfortunately have to catch this here, as track_provisioned - # will not see this, since JSONResponseError is a subclass and - # will capture it - raise - except JSONResponseError: # pragma: nocover + except Boto3Error: # pragma: nocover # We trap JSONResponseError because Moto returns text instead of # JSON when looking up values in empty tables. We re-throw the # correct ItemNotFound exception @@ -612,8 +755,7 @@ def register_user(self, data): """ # Fetch a senderid for this user - conn = self.table.connection - db_key = self.encode({"uaid": hasher(data["uaid"])}) + db_key = {"uaid": hasher(data["uaid"])} del data["uaid"] if "router_type" not in data or "connected_at" not in data: # Not specifying these values will generate an exception in AWS. @@ -621,7 +763,7 @@ def register_user(self, data): "or connected_at") # Generate our update expression expr = "SET " + ", ".join(["%s=:%s" % (x, x) for x in data.keys()]) - expr_values = self.encode({":%s" % k: v for k, v in data.items()}) + expr_values = {":%s" % k: v for k, v in data.items()} try: cond = """( attribute_not_exists(router_type) or @@ -630,13 +772,12 @@ def register_user(self, data): attribute_not_exists(node_id) or (connected_at < :connected_at) )""" - result = conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - condition_expression=cond, - expression_attribute_values=expr_values, - return_values="ALL_OLD", + result = self.table.update_item( + Key=db_key, + UpdateExpression=expr, + ConditionExpression=cond, + ExpressionAttributeValues=expr_values, + ReturnValues="ALL_OLD", ) if "Attributes" in result: r = {} @@ -649,8 +790,13 @@ def register_user(self, data): r[key] = value result = r return (True, result) - except ConditionalCheckFailedException: - return (False, {}) + except ClientError as ex: + # ClientErrors are generated by a factory, and while they have a + # class, it's dynamically generated. + if ex.response['Error']['Code'] == \ + 'ConditionalCheckFailedException': + return (False, {}) + raise @track_provisioned def drop_user(self, uaid): @@ -658,16 +804,26 @@ def drop_user(self, uaid): """Drops a user record""" # The following hack ensures that only uaids that exist and are # deleted return true. - huaid = hasher(uaid) - return self.table.delete_item(uaid=huaid, - expected={"uaid__eq": huaid}) + try: + item = self.table.get_item( + Key={ + 'uaid': hasher(uaid) + }, + ConsistentRead=True, + ) + if 'Item' not in item: + return False + except ClientError: + pass + result = self.table.delete_item(Key={'uaid': hasher(uaid)}) + return result['ResponseMetadata']['HTTPStatusCode'] == 200 def delete_uaids(self, uaids): # type: (List[str]) -> None """Issue a batch delete call for the given uaids""" - with self.table.batch_write() as batch: + with self.table.batch_writer() as batch: for uaid in uaids: - batch.delete_item(uaid=uaid) + batch.delete_item(Key={'uaid': uaid}) def drop_old_users(self, months_ago=2): # type: (int) -> Iterable[int] @@ -699,10 +855,11 @@ def drop_old_users(self, months_ago=2): batched = [] for hash_key in generate_last_connect_values(prior_date): - result_set = self.table.query_2( - last_connect__eq=hash_key, - index="AccessIndex", + response = self.table.query( + KeyConditionExpression=Key("last_connect").eq(hash_key), + IndexName="AccessIndex", ) + result_set = response.get('Items', []) for result in result_set: batched.append(result["uaid"]) @@ -716,6 +873,14 @@ def drop_old_users(self, months_ago=2): self.delete_uaids(batched) yield len(batched) + @track_provisioned + def _update_last_connect(self, uaid, last_connect): + self.table.update_item( + Key={"uaid": hasher(uaid)}, + UpdateExpression="SET last_connect=:last_connect", + ExpressionAttributeValues={":last_connect": last_connect} + ) + @track_provisioned def update_message_month(self, uaid, month): # type: (str, str) -> bool @@ -727,23 +892,23 @@ def update_message_month(self, uaid, month): timestamp. """ - conn = self.table.connection - db_key = self.encode({"uaid": hasher(uaid)}) - expr = ("SET current_month=:curmonth, last_connect=:last_connect") - expr_values = self.encode({":curmonth": month, - ":last_connect": generate_last_connect(), - }) - conn.update_item( - self.table.table_name, - db_key, - update_expression=expr, - expression_attribute_values=expr_values, + db_key = {"uaid": hasher(uaid)} + expr = ("SET current_month=:curmonth, last_connect=:last_connect, " + "expiry=:expiry") + expr_values = {":curmonth": month, + ":last_connect": generate_last_connect(), + ":expiry": _expiry(self._max_ttl), + } + self.table.update_item( + Key=db_key, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, ) return True @track_provisioned def clear_node(self, item): - # type: (Item) -> bool + # type: (dict) -> bool """Given a router item and remove the node_id The node_id will only be cleared if the ``connected_at`` matches up @@ -755,24 +920,26 @@ def clear_node(self, item): exceeds throughput. """ - conn = self.table.connection # Pop out the node_id node_id = item["node_id"] del item["node_id"] try: cond = "(node_id = :node) and (connected_at = :conn)" - conn.put_item( - self.table.table_name, - item=self.encode(item), - condition_expression=cond, - expression_attribute_values=self.encode({ + self.table.put_item( + Item=item, + ConditionExpression=cond, + ExpressionAttributeValues={ ":node": node_id, ":conn": item["connected_at"], - }), + }, ) return True - except ConditionalCheckFailedException: + except ClientError as ex: + if (ex.response["Error"]["Code"] == + "ProvisionedThroughputExceededException"): + raise + # UAID not found. return False @@ -789,6 +956,8 @@ class DatabaseManager(object): message_tables = attrib(default=Factory(dict)) # type: Dict[str, Message] current_msg_month = attrib(init=False) # type: Optional[str] current_month = attrib(init=False) # type: Optional[int] + # for testing: + client = attrib(default=g_client) # type: Optional[Any] def __attrs_post_init__(self): """Initialize sane defaults""" diff --git a/autopush/tests/__init__.py b/autopush/tests/__init__.py index 9940b8ee..f9ef6b6a 100644 --- a/autopush/tests/__init__.py +++ b/autopush/tests/__init__.py @@ -4,8 +4,11 @@ import subprocess import boto +import botocore +import boto3 import psutil +import autopush.db from autopush.db import create_rotating_message_table here_dir = os.path.abspath(os.path.dirname(__file__)) @@ -18,17 +21,28 @@ def setUp(): logging.getLogger('boto').setLevel(logging.CRITICAL) - boto_path = os.path.join(root_dir, "automock", "boto.cfg") - boto.config.load_from_path(boto_path) global ddb_process cmd = " ".join([ "java", "-Djava.library.path=%s" % ddb_lib_dir, "-jar", ddb_jar, "-sharedDb", "-inMemory" ]) + conf = botocore.config.Config( + region_name=os.getenv('AWS_REGION_NAME', 'us-east-1') + ) ddb_process = subprocess.Popen(cmd, shell=True, env=os.environ) + autopush.db.g_dynamodb = boto3.resource( + 'dynamodb', + config=conf, + endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB", "http://127.0.0.1:8000"), + aws_access_key_id="BogusKey", + aws_secret_access_key="BogusKey", + ) + + autopush.db.g_client = autopush.db.g_dynamodb.meta.client # Setup the necessary message tables message_table = os.environ.get("MESSAGE_TABLE", "message_int_test") + create_rotating_message_table(prefix=message_table, delta=-1) create_rotating_message_table(prefix=message_table) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index cd6c71df..920565dc 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -4,12 +4,9 @@ from autopush.websocket import ms_time from boto.dynamodb2.exceptions import ( - ConditionalCheckFailedException, - ProvisionedThroughputExceededException, ItemNotFound, ) -from boto.dynamodb2.layer1 import DynamoDBConnection -from boto.dynamodb2.items import Item +from botocore.exceptions import ClientError from mock import Mock import pytest @@ -22,7 +19,9 @@ Message, Router, generate_last_connect, - make_rotating_tablename) + make_rotating_tablename, + _drop_table, + _make_table) from autopush.exceptions import AutopushException from autopush.metrics import SinkMetrics from autopush.utils import WebPushNotification @@ -43,6 +42,15 @@ def make_webpush_notification(uaid, chid, ttl=100): ) +class DbUtilsTest(unittest.TestCase): + def test_make_table(self): + fake_func = Mock() + fake_table = "DoesNotExist_{}".format(uuid.uuid4()) + + _make_table(fake_func, fake_table, 5, 10) + assert fake_func.call_args[0] == (fake_table, 5, 10) + + class DbCheckTestCase(unittest.TestCase): def test_preflight_check_fail(self): router = Router(get_router_table(), SinkMetrics()) @@ -73,17 +81,9 @@ def test_preflight_check_wait(self): router = Router(get_router_table(), SinkMetrics()) message = Message(get_rotating_message_table(), SinkMetrics()) - message.table.describe = mock_describe = Mock() - - values = [ - dict(Table=dict(TableStatus="PENDING")), - dict(Table=dict(TableStatus="ACTIVE")), - ] - - def return_vals(*args, **kwargs): - return values.pop(0) + values = ["PENDING", "ACTIVE"] + message.table_status = Mock(side_effect=values) - mock_describe.side_effect = return_vals pf_uaid = "deadbeef00000000deadbeef01010101" preflight_check(message, router, pf_uaid) # now check that the database reports no entries. @@ -128,11 +128,10 @@ class MessageTestCase(unittest.TestCase): def setUp(self): table = get_rotating_message_table() self.real_table = table - self.real_connection = table.connection self.uaid = str(uuid.uuid4()) def tearDown(self): - self.real_table.connection = self.real_connection + pass def test_register(self): chid = str(uuid.uuid4()) @@ -140,10 +139,21 @@ def test_register(self): message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) - # Verify its in the db - rows = m.query_2(uaid__eq=self.uaid, chidmessageid__eq=" ") - results = list(rows) - assert len(results) == 1 + # Verify it's in the db + response = m.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [self.uaid], + 'ComparisonOperator': 'EQ' + }, + 'chidmessageid': { + 'AttributeValueList': ['02'], + 'ComparisonOperator': 'LT' + } + }, + ConsistentRead=True, + ) + assert len(response.get('Items')) def test_unregister(self): chid = str(uuid.uuid4()) @@ -152,24 +162,48 @@ def test_unregister(self): message.register_channel(self.uaid, chid) # Verify its in the db - rows = m.query_2(uaid__eq=self.uaid, chidmessageid__eq=" ") - results = list(rows) + response = m.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [self.uaid], + 'ComparisonOperator': 'EQ' + }, + 'chidmessageid': { + 'AttributeValueList': [" "], + 'ComparisonOperator': 'EQ' + }, + }, + ConsistentRead=True, + ) + results = list(response.get('Items')) assert len(results) == 1 assert results[0]["chids"] == {chid} message.unregister_channel(self.uaid, chid) # Verify its not in the db - rows = m.query_2(uaid__eq=self.uaid, chidmessageid__eq=" ") - results = list(rows) + response = m.query( + KeyConditions={ + 'uaid': { + 'AttributeValueList': [self.uaid], + 'ComparisonOperator': 'EQ' + }, + 'chidmessageid': { + 'AttributeValueList': [" "], + 'ComparisonOperator': 'EQ' + }, + }, + ConsistentRead=True, + ) + results = list(response.get('Items')) assert len(results) == 1 - assert results[0]["chids"] is None + assert results[0].get("chids") is None # Test for the very unlikely case that there's no 'chid' - m.connection.update_item = Mock() - m.connection.update_item.return_value = { - 'Attributes': {'uaid': {'S': self.uaid}}, - 'ConsumedCapacityUnits': 0.5} + m.update_item = Mock(return_value={ + 'Attributes': {'uaid': self.uaid}, + 'ResponseMetaData': {} + }) r = message.unregister_channel(self.uaid, dummy_chid) assert r is False @@ -190,6 +224,20 @@ def test_all_channels(self): assert chid2 not in chans assert chid in chans + def test_all_channels_fail(self): + m = get_rotating_message_table() + message = Message(m, SinkMetrics()) + + message.table.get_item = Mock() + message.table.get_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 400 + }, + } + + res = message.all_channels(self.uaid) + assert res == (False, set([])) + def test_save_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) @@ -254,7 +302,7 @@ def test_message_delete_fail_condition(self): message = Message(m, SinkMetrics()) def raise_condition(*args, **kwargs): - raise ConditionalCheckFailedException(None, None) + raise ClientError({}, 'delete_item') message.table = Mock() message.table.delete_item.side_effect = raise_condition @@ -268,6 +316,8 @@ def test_message_rotate_table_with_date(self): m = get_rotating_message_table(prefix=prefix, date=future) assert m.table_name == tbl_name + # Clean up the temp table. + _drop_table(tbl_name) class RouterTestCase(unittest.TestCase): @@ -275,11 +325,11 @@ class RouterTestCase(unittest.TestCase): def setup_class(self): table = get_router_table() self.real_table = table - self.real_connection = table.connection + self.real_connection = table.meta.client @classmethod def teardown_class(self): - self.real_table.connection = self.real_connection + self.real_table.meta.client = self.real_connection def _create_minimal_record(self): data = { @@ -294,6 +344,8 @@ def test_drop_old_users(self): # First create a bunch of users r = get_router_table() router = Router(r, SinkMetrics()) + # Purge any existing users from previous runs. + router.drop_old_users(0) for _ in range(0, 53): router.register_user(self._create_minimal_record()) @@ -301,18 +353,19 @@ def test_drop_old_users(self): assert list(results) == [25, 25, 3] def test_custom_tablename(self): - db = DynamoDBConnection() db_name = "router_%s" % uuid.uuid4() - assert not table_exists(db, db_name) + assert not table_exists(db_name) create_router_table(db_name) - assert table_exists(db, db_name) + assert table_exists(db_name) + # Clean up the temp table. + _drop_table(db_name) def test_provisioning(self): db_name = "router_%s" % uuid.uuid4() r = create_router_table(db_name, 3, 17) - assert r.throughput["read"] == 3 - assert r.throughput["write"] == 17 + assert r.provisioned_throughput.get('ReadCapacityUnits') == 3 + assert r.provisioned_throughput.get('WriteCapacityUnits') == 17 def test_no_uaid_found(self): uaid = str(uuid.uuid4()) @@ -326,41 +379,94 @@ def test_uaid_provision_failed(self): router = Router(r, SinkMetrics()) router.table = Mock() - def raise_error(*args, **kwargs): - raise ProvisionedThroughputExceededException(None, None) - - router.table.get_item.side_effect = raise_error - with pytest.raises(ProvisionedThroughputExceededException): + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) + + router.table.get_item.side_effect = raise_condition + with pytest.raises(ClientError) as ex: router.get_uaid(uaid="asdf") + assert (ex.value.response['Error']['Code'] == + "ProvisionedThroughputExceededException") def test_register_user_provision_failed(self): r = get_router_table() router = Router(r, SinkMetrics()) - router.table.connection = Mock() - - def raise_error(*args, **kwargs): - raise ProvisionedThroughputExceededException(None, None) + router.table.meta.client = Mock() - router.table.connection.update_item.side_effect = raise_error - with pytest.raises(ProvisionedThroughputExceededException): + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) + + router.table.update_item = Mock(side_effect=raise_condition) + with pytest.raises(ClientError) as ex: router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, router_type="webpush")) + assert (ex.value.response['Error']['Code'] == + "ProvisionedThroughputExceededException") - def test_clear_node_provision_failed(self): + def test_register_user_condition_failed(self): r = get_router_table() router = Router(r, SinkMetrics()) - router.table.connection.put_item = Mock() + router.table.meta.client = Mock() def raise_error(*args, **kwargs): - raise ProvisionedThroughputExceededException(None, None) + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ConditionalCheckFailedException'}}, + 'mock_update_item' + ) + + router.table.update_item = Mock(side_effect=raise_error) + res = router.register_user(dict(uaid=dummy_uaid, node_id="me", + connected_at=1234, + router_type="webpush")) + assert res == (False, {}) - router.table.connection.put_item.side_effect = raise_error - with pytest.raises(ProvisionedThroughputExceededException): - router.clear_node(Item(r, dict(uaid=dummy_uaid, - connected_at="1234", - node_id="asdf", - router_type="webpush"))) + def test_clear_node_provision_failed(self): + r = get_router_table() + router = Router(r, SinkMetrics()) + + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) + + router.table.put_item = Mock(side_effect=raise_condition) + with pytest.raises(ClientError) as ex: + router.clear_node(dict(uaid=dummy_uaid, + connected_at="1234", + node_id="asdf", + router_type="webpush")) + assert (ex.value.response['Error']['Code'] == + "ProvisionedThroughputExceededException") + + def test_clear_node_condition_failed(self): + r = get_router_table() + router = Router(r, SinkMetrics()) + + def raise_error(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ConditionalCheckFailedException'}}, + 'mock_put_item' + ) + + router.table.put_item = Mock(side_effect=raise_error) + res = router.clear_node(dict(uaid=dummy_uaid, + connected_at="1234", + node_id="asdf", + router_type="webpush")) + assert res is False def test_incomplete_uaid(self): # Older records may be incomplete. We can't inject them using normal @@ -370,7 +476,14 @@ def test_incomplete_uaid(self): router = Router(r, SinkMetrics()) router.table.get_item = Mock() router.drop_user = Mock() - router.table.get_item.return_value = {"uaid": uuid.uuid4().hex} + router.table.get_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 200 + }, + "Item": { + "uaid": uuid.uuid4().hex + } + } try: router.register_user(dict(uaid=uaid)) except AutopushException: @@ -379,14 +492,28 @@ def test_incomplete_uaid(self): router.get_uaid(uaid) assert router.drop_user.called + def test_failed_uaid(self): + uaid = str(uuid.uuid4()) + r = get_router_table() + router = Router(r, SinkMetrics()) + router.table.get_item = Mock() + router.drop_user = Mock() + router.table.get_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 400 + }, + } + with pytest.raises(ItemNotFound): + router.get_uaid(uaid) + def test_save_new(self): r = get_router_table() router = Router(r, SinkMetrics()) # Sadly, moto currently does not return an empty value like boto # when not updating data. - router.table.connection = Mock() - router.table.connection.update_item.return_value = {} - result = router.register_user(dict(uaid="", node_id="me", + router.table.update_item = Mock(return_value={}) + result = router.register_user(dict(uaid=dummy_uaid, + node_id="me", router_type="webpush", connected_at=1234)) assert result[0] is True @@ -396,10 +523,13 @@ def test_save_fail(self): router = Router(r, SinkMetrics()) def raise_condition(*args, **kwargs): - raise ConditionalCheckFailedException(None, None) + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ConditionalCheckFailedException'}}, + 'mock_update_item' + ) - router.table.connection = Mock() - router.table.connection.update_item.side_effect = raise_condition + router.table.update_item = Mock(side_effect=raise_condition) router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, router_type="webpush") result = router.register_user(router_data) @@ -433,12 +563,14 @@ def test_node_clear_fail(self): router = Router(r, SinkMetrics()) def raise_condition(*args, **kwargs): - raise ConditionalCheckFailedException(None, None) + raise ClientError( + {'Error': {'Code': 'ConditionalCheckFailedException'}}, + 'mock_update_item' + ) - router.table.connection.put_item = Mock() - router.table.connection.put_item.side_effect = raise_condition + router.table.put_item = Mock(side_effect=raise_condition) data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) - result = router.clear_node(Item(r, data)) + result = router.clear_node(data) assert result is False def test_drop_user(self): diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index 4a8fdb41..fdf38cc9 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -11,7 +11,6 @@ import autopush.utils as utils from autopush.config import AutopushConfig from autopush.db import ( - ProvisionedThroughputExceededException, Message, ItemNotFound, has_connected_this_month, @@ -121,11 +120,18 @@ def test_delete_topic_error_parts(self): def test_delete_db_error(self): tok = ":".join(["m", dummy_uaid.hex, str(dummy_chid)]) self.fernet_mock.decrypt.return_value = tok - self.message_mock.configure_mock(**{ - "delete_message.side_effect": - ProvisionedThroughputExceededException(None, None)}) + + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ConditionalCheckFailedException'}}, + 'mock_update_item' + ) + + self.message_mock.configure_mock(**{"delete_message.return_value": + False}) resp = yield self.client.delete(self.url(message_id="ignored")) - assert resp.get_status() == 503 + assert 204 == resp.get_status() class RegistrationTestCase(unittest.TestCase): diff --git a/autopush/tests/test_health.py b/autopush/tests/test_health.py index df47605c..3e96896f 100644 --- a/autopush/tests/test_health.py +++ b/autopush/tests/test_health.py @@ -7,6 +7,7 @@ from twisted.logger import globalLogPublisher from twisted.trial import unittest +import autopush.db from autopush import __version__ from autopush.config import AutopushConfig from autopush.db import DatabaseManager @@ -27,7 +28,9 @@ def setUp(self): hostname="localhost", statsd_host=None, ) + db = DatabaseManager.from_config(conf) + db.client = autopush.db.g_client db.setup_tables() # ignore logging @@ -52,32 +55,36 @@ def test_healthy(self): @inlineCallbacks def test_aws_error(self): - from autopush.db import make_rotating_tablename def raise_error(*args, **kwargs): raise InternalServerError(None, None) - self.router_table.connection.list_tables = Mock( - side_effect=raise_error) - table_name = make_rotating_tablename("message") - self.message.table.connection.list_tables = Mock( - return_value={"TableNames": [table_name]}) + + safe = self.client.app.db.client + self.client.app.db.client = Mock() + self.client.app.db.client.list_tables = Mock(side_effect=raise_error) yield self._assert_reply({ "status": "NOT OK", "version": __version__, "clients": 0, - "storage": {"status": "OK"}, + "storage": { + "status": "NOT OK", + "error": "Server error" + }, "router": { "status": "NOT OK", "error": "Server error" } }, InternalServerError) + self.client.app.db.client = safe + @inlineCallbacks def test_nonexistent_table(self): no_tables = Mock(return_value={"TableNames": []}) - self.message.table.connection.list_tables = no_tables - self.router_table.connection.list_tables = no_tables + safe = self.client.app.db.client + self.client.app.db.client = Mock() + self.client.app.db.client.list_tables = no_tables yield self._assert_reply({ "status": "NOT OK", @@ -92,15 +99,18 @@ def test_nonexistent_table(self): "error": "Nonexistent table" } }, MissingTableException) + self.client.app.db.client = safe @inlineCallbacks def test_internal_error(self): def raise_error(*args, **kwargs): raise Exception("synergies not aligned") - self.router_table.connection.list_tables = Mock( - return_value={"TableNames": ["router"]}) - self.message.table.connection.list_tables = Mock( - side_effect=raise_error) + + safe = self.client.app.db.client + self.client.app.db.client = Mock() + self.client.app.db.client.list_tables = Mock( + side_effect=raise_error + ) yield self._assert_reply({ "status": "NOT OK", @@ -110,8 +120,12 @@ def raise_error(*args, **kwargs): "status": "NOT OK", "error": "Internal error" }, - "router": {"status": "OK"} + "router": { + "status": "NOT OK", + "error": "Internal error" + } }, Exception) + self.client.app.db.client = safe @inlineCallbacks def _assert_reply(self, reply, exception=None): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index b25fa39e..4259dd61 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -320,12 +320,14 @@ def setUp(self): # Endpoint HTTP router self.ep = ep = EndpointApplication(ep_conf) + ep.db.client = db.g_client ep.setup(rotate_tables=False) ep.startService() self.addCleanup(ep.stopService) # Websocket server self.conn = conn = ConnectionApplication(conn_conf) + conn.db.client = db.g_client conn.setup(rotate_tables=False) conn.startService() self.addCleanup(conn.stopService) @@ -1167,9 +1169,15 @@ def test_webpush_monthly_rotation(self): # Verify last_connect is current, then move that back assert has_connected_this_month(c) today = get_month(delta=-1) - c["last_connect"] = int("%s%s020001" % (today.year, - str(today.month).zfill(2))) - yield deferToThread(c.partial_save) + last_connect = int("%s%s020001" % (today.year, + str(today.month).zfill(2))) + + yield deferToThread( + self.conn.db.router._update_last_connect, + client.uaid, + last_connect, + ) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month @@ -1187,9 +1195,8 @@ def test_webpush_monthly_rotation(self): # Remove the channels entry entirely from this month yield deferToThread( self.conn.db.message.table.delete_item, - uaid=client.uaid, - chidmessageid=" " - ) + Key={'uaid': client.uaid, 'chidmessageid': ' '} + ) # Verify the channel is gone exists, chans = yield deferToThread( @@ -1278,9 +1285,12 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Verify last_connect is current, then move that back assert has_connected_this_month(c) today = get_month(delta=-1) - c["last_connect"] = int("%s%s020001" % (today.year, - str(today.month).zfill(2))) - yield deferToThread(c.partial_save) + yield deferToThread( + self.conn.db.router._update_last_connect, + client.uaid, + int("%s%s020001" % (today.year, str(today.month).zfill(2))) + ) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month diff --git a/autopush/tests/test_web_base.py b/autopush/tests/test_web_base.py index 8a5620d3..dc56255d 100644 --- a/autopush/tests/test_web_base.py +++ b/autopush/tests/test_web_base.py @@ -2,6 +2,7 @@ import uuid from boto.exception import BotoServerError +from botocore.exceptions import ClientError from mock import Mock, patch from twisted.internet.defer import Deferred from twisted.logger import Logger @@ -9,7 +10,6 @@ from twisted.trial import unittest from autopush.config import AutopushConfig -from autopush.db import ProvisionedThroughputExceededException from autopush.http import EndpointHTTPFactory from autopush.exceptions import InvalidRequest from autopush.metrics import SinkMetrics @@ -195,12 +195,30 @@ def test_response_err(self): def test_overload_err(self): try: - raise ProvisionedThroughputExceededException("error", None, None) - except ProvisionedThroughputExceededException: + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': { + 'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) + except ClientError: fail = Failure() self.base._overload_err(fail) self.status_mock.assert_called_with(503, reason=None) + def test_client_err(self): + try: + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': { + 'Code': 'Flibbertygidgit'}}, + 'mock_update_item' + ) + except ClientError: + fail = Failure() + self.base._overload_err(fail) + self.status_mock.assert_called_with(500, reason=None) + def test_boto_err(self): try: raise BotoServerError(503, "derp") diff --git a/autopush/tests/test_webpush_server.py b/autopush/tests/test_webpush_server.py index 324bb3fe..4df8c042 100644 --- a/autopush/tests/test_webpush_server.py +++ b/autopush/tests/test_webpush_server.py @@ -7,7 +7,6 @@ import attr import factory from boto.dynamodb2.exceptions import ItemNotFound -from boto.dynamodb2.exceptions import ProvisionedThroughputExceededException from mock import Mock from twisted.logger import globalLogPublisher import pytest @@ -499,8 +498,16 @@ def test_register_bad_chid_nodash(self): self._test_invalid(uuid4().hex) def test_register_over_provisioning(self): - self.db.message.register_channel = Mock( - side_effect=ProvisionedThroughputExceededException(None, None)) + + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) + + self.db.message.table.update_item = Mock( + side_effect=raise_condition) self._test_invalid(str(uuid4()), "overloaded", 503) diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 85681e35..9f973dad 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -8,10 +8,6 @@ import twisted.internet.base from autobahn.twisted.util import sleep from autobahn.websocket.protocol import ConnectionRequest -from boto.dynamodb2.exceptions import ( - ProvisionedThroughputExceededException, -) -from boto.exception import JSONResponseError from mock import Mock, patch import pytest from twisted.internet import reactor @@ -166,7 +162,8 @@ def get_response(self): """ calls = self.send_mock.call_args_list - yield self._wait_for(lambda: len(calls)) + yield self._wait_for(lambda: len(calls), + duration=4000) args = calls.pop(0) msg = args[0][0] returnValue(json.loads(msg)) @@ -454,12 +451,14 @@ def test_hello_old(self): "current_month": msg_date, } router = self.proto.db.router - router.table.put_item(data=dict( - uaid=orig_uaid, - connected_at=ms_time(), - current_month=msg_date, - router_type="webpush" - )) + router.table.put_item( + Item=dict( + uaid=orig_uaid, + connected_at=ms_time(), + current_month=msg_date, + router_type="webpush" + ) + ) def fake_msg(data): return (True, msg_data) @@ -600,11 +599,15 @@ def fake_msg(data): mock_patch = patch_range.start() mock_patch.return_value = 1 - def raise_error(*args): - raise ProvisionedThroughputExceededException(None, None) + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) self.proto.db.router.update_message_month = MockAssist([ - raise_error, + raise_condition, Mock(), ]) @@ -708,7 +711,7 @@ def test_hello_failure(self): self._connect() # Fail out the register_user call router = self.proto.db.router - router.table.connection.update_item = Mock(side_effect=KeyError) + router.table.update_item = Mock(side_effect=KeyError) self._send_message(dict(messageType="hello", channelIDs=[], use_webpush=True, stop=1)) @@ -723,31 +726,15 @@ def test_hello_provisioned_during_check(self): self.proto.randrange = Mock(return_value=0.1) # Fail out the register_user call - def throw_error(*args, **kwargs): - raise ProvisionedThroughputExceededException(None, None) - - router = self.proto.db.router - router.table.connection.update_item = Mock(side_effect=throw_error) - - self._send_message(dict(messageType="hello", use_webpush=True, - channelIDs=[])) - msg = yield self.get_response() - assert msg["status"] == 503 - assert msg["reason"] == "error - overloaded" - self.flushLoggedErrors() - - @inlineCallbacks - def test_hello_jsonresponseerror(self): - self._connect() - - self.proto.randrange = Mock() - self.proto.randrange.return_value = 0.1 - - def throw_error(*args, **kwargs): - raise JSONResponseError(None, None) + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) router = self.proto.db.router - router.table.connection.update_item = Mock(side_effect=throw_error) + router.table.update_item = Mock(side_effect=raise_condition) self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) @@ -1061,10 +1048,14 @@ def test_register_over_provisioning(self): self.proto.ps.uaid = uuid.uuid4().hex self.proto.db.message.register_channel = register = Mock() - def throw_provisioned(*args, **kwargs): - raise ProvisionedThroughputExceededException(None, None) + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) - register.side_effect = throw_provisioned + register.side_effect = raise_condition yield self.proto.process_register(dict(channelID=chid)) assert self.proto.db.message.register_channel.called @@ -1374,12 +1365,13 @@ def wait(result): return d def test_process_notifications_provision_err(self): - from boto.dynamodb2.exceptions import ( - ProvisionedThroughputExceededException - ) - def throw(*args, **kwargs): - raise ProvisionedThroughputExceededException(500, "whoops") + def raise_condition(*args, **kwargs): + import autopush.db + raise autopush.db.g_client.exceptions.ClientError( + {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) twisted.internet.base.DelayedCall.debug = True self._connect() @@ -1387,9 +1379,9 @@ def throw(*args, **kwargs): return_value=(None, []) ) self.proto.db.message.fetch_messages = Mock( - side_effect=throw) + side_effect=raise_condition) self.proto.db.message.fetch_timestamp_messages = Mock( - side_effect=throw) + side_effect=raise_condition) self.proto.deferToLater = Mock() self.proto.ps.uaid = uuid.uuid4().hex @@ -1402,7 +1394,7 @@ def throw(*args, **kwargs): notif_d.addErrback(lambda x: d.errback(x)) def wait(result): - assert self.proto.deferToLater.called + assert self.proto.deferToLater.called, "Defer not called" d.callback(True) self.proto.ps._notification_fetch.addCallback(wait) diff --git a/autopush/utils.py b/autopush/utils.py index 66c35ff7..84d094e5 100644 --- a/autopush/utils.py +++ b/autopush/utils.py @@ -1,4 +1,3 @@ -"""A small collection of Autopush utility functions""" import base64 import hashlib import hmac @@ -475,7 +474,6 @@ def from_message_table(cls, uaid, item): key_info = cls.parse_sort_key(item["chidmessageid"]) if key_info["api_ver"] in ["01", "02"]: key_info["message_id"] = item["updateid"] - notif = cls( uaid=uaid, channel_id=uuid.UUID(key_info["channel_id"]), @@ -505,9 +503,13 @@ def from_webpush_request_schema(cls, data, fernet, legacy=False): """ sub = data["subscription"] notif = cls( - uaid=sub["uaid"], channel_id=sub["chid"], data=data["body"], - headers=data["headers"], ttl=data["headers"]["ttl"], - topic=data["headers"]["topic"], legacy=legacy, + uaid=sub["uaid"], + channel_id=sub["chid"], + data=data["body"], + headers=data["headers"], + ttl=data["headers"]["ttl"], + topic=data["headers"]["topic"], + legacy=legacy, ) if notif.data: diff --git a/autopush/web/base.py b/autopush/web/base.py index 00b1d0dd..3b7055ae 100644 --- a/autopush/web/base.py +++ b/autopush/web/base.py @@ -2,8 +2,8 @@ import time from functools import wraps -from boto.dynamodb2.exceptions import ProvisionedThroughputExceededException from boto.exception import BotoServerError +from botocore.exceptions import ClientError from marshmallow.schema import UnmarshalResult # noqa from typing import ( # noqa Any, @@ -247,11 +247,18 @@ def _response_err(self, fail): def _overload_err(self, fail): """errBack for throughput provisioned exceptions""" - fail.trap(ProvisionedThroughputExceededException) - self.log.debug(format="Throughput Exceeded", status_code=503, - errno=201, client_info=self._client_info) - self._write_response(503, 201, - message="Please slow message send rate") + fail.trap(ClientError) + if (fail.value.response['Error']['Code'] == + "ProvisionedThroughputExceededException"): + self.log.debug(format="Throughput Exceeded", status_code=503, + errno=201, client_info=self._client_info) + self._write_response(503, 201, + message="Please slow message send rate") + return + self.log.debug(format="Unhandled Client Error: {}".format( + json.dumps(fail.value.response)), status_code=500, + client_info=self._client_info) + self._write_response(500, 999, message="Unexpected Error") def _boto_err(self, fail): """errBack for random boto exceptions""" diff --git a/autopush/web/health.py b/autopush/web/health.py index 2bd5fe4f..d28265b2 100644 --- a/autopush/web/health.py +++ b/autopush/web/health.py @@ -40,7 +40,7 @@ def get(self): def _check_table(self, table, name_over=None): """Checks the tables known about in DynamoDB""" - d = deferToThread(table_exists, table.connection, table.table_name) + d = deferToThread(table_exists, table.table_name, self.db.client) d.addCallback(self._check_success, name_over or table.table_name) d.addErrback(self._check_error, name_over or table.table_name) return d diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index ac8c29d2..ab387f43 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -28,7 +28,7 @@ from autopush.crypto_key import CryptoKey from autopush.db import DatabaseManager # noqa from autopush.metrics import Metrics, make_tags # noqa -from autopush.db import dump_uaid, hasher +from autopush.db import hasher from autopush.exceptions import ( InvalidRequest, InvalidTokenException, @@ -93,7 +93,7 @@ def validate_uaid_month_and_chid(self, d): if router_type not in VALID_ROUTER_TYPES: self.context["log"].debug(format="Dropping User", code=102, uaid_hash=hasher(result["uaid"]), - uaid_record=dump_uaid(result)) + uaid_record=repr(result)) self.context["metrics"].increment("updates.drop_user", tags=make_tags(errno=102)) self.context["db"].router.drop_user(result["uaid"]) @@ -134,9 +134,8 @@ def _validate_webpush(self, d, result): if 'current_month' not in result: log.debug(format="Dropping User", code=102, uaid_hash=hasher(uaid), - uaid_record=dump_uaid(result)) - metrics.increment("updates.drop_user", - tags=make_tags(errno=102)) + uaid_record=repr(result)) + metrics.increment("updates.drop_user", tags=make_tags(errno=102)) db.router.drop_user(uaid) raise InvalidRequest("No such subscription", status_code=410, errno=106) @@ -145,9 +144,8 @@ def _validate_webpush(self, d, result): if month_table not in db.message_tables: log.debug(format="Dropping User", code=103, uaid_hash=hasher(uaid), - uaid_record=dump_uaid(result)) - metrics.increment("updates.drop_user", - tags=make_tags(errno=103)) + uaid_record=repr(result)) + metrics.increment("updates.drop_user", tags=make_tags(errno=103)) db.router.drop_user(uaid) raise InvalidRequest("No such subscription", status_code=410, errno=106) @@ -533,7 +531,7 @@ def _router_completed(self, response, uaid_data, warning="", # this record. self.log.debug(format="Dropping User", code=100, uaid_hash=hasher(uaid_data["uaid"]), - uaid_record=dump_uaid(uaid_data), + uaid_record=repr(uaid_data), client_info=self._client_info) d = deferToThread(self.db.router.drop_user, uaid_data["uaid"]) d.addCallback(lambda x: self._router_response(response, diff --git a/autopush/webpush_server.py b/autopush/webpush_server.py index f74a7e94..b7a5fda6 100644 --- a/autopush/webpush_server.py +++ b/autopush/webpush_server.py @@ -10,7 +10,7 @@ attrib, ) from boto.dynamodb2.exceptions import ItemNotFound -from boto.exception import JSONResponseError +from botocore.exceptions import ClientError from typing import ( # noqa Dict, List, @@ -570,8 +570,11 @@ def process(self, command): message = self.db.message_tables[command.message_month] try: message.register_channel(command.uaid.hex, command.channel_id) - except JSONResponseError: - return RegisterErrorResponse(error_msg="overloaded", status=503) + except ClientError as ex: + if (ex.response['Error']['Code'] == + "ProvisionedThroughputExceededException"): + return RegisterErrorResponse(error_msg="overloaded", + status=503) self.metrics.increment('ua.command.register') log.info( diff --git a/autopush/websocket.py b/autopush/websocket.py index 3bb5784b..a3de965a 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -49,10 +49,9 @@ ) from autobahn.websocket.protocol import ConnectionRequest # noqa from boto.dynamodb2.exceptions import ( - ProvisionedThroughputExceededException, ItemNotFound ) -from boto.exception import JSONResponseError +from botocore.exceptions import ClientError from twisted.internet import reactor from twisted.internet.defer import ( Deferred, @@ -89,7 +88,6 @@ has_connected_this_month, hasher, generate_last_connect, - dump_uaid, ) from autopush.db import DatabaseManager, Message # noqa from autopush.exceptions import MessageOverloadException @@ -665,7 +663,7 @@ def err_overload(self, failure, message_type, disconnect=True): :param disconnect: Whether the client should be disconnected or not. """ - failure.trap(JSONResponseError) + failure.trap(ClientError) if disconnect: self.transport.pauseProducing() @@ -673,6 +671,9 @@ def err_overload(self, failure, message_type, disconnect=True): self.err_finish_overload, message_type) d.addErrback(self.trap_cancel) else: + if (failure.value.response["Error"]["Code"] != + "ProvisionedThroughputExceededException"): + return failure # pragma nocover send = {"messageType": "error", "reason": "overloaded", "status": 503} self.sendJSON(send) @@ -765,7 +766,7 @@ def _verify_user_record(self): if "router_type" not in record or "connected_at" not in record: self.log.debug(format="Dropping User", code=104, uaid_hash=self.ps.uaid_hash, - uaid_record=dump_uaid(record)) + uaid_record=repr(record)) tags = ['code:104'] self.metrics.increment("ua.expiration", tags=tags) self.force_retry(self.db.router.drop_user, self.ps.uaid) @@ -777,7 +778,7 @@ def _verify_user_record(self): not in self.db.message_tables: self.log.debug(format="Dropping User", code=105, uaid_hash=self.ps.uaid_hash, - uaid_record=dump_uaid(record)) + uaid_record=repr(record)) self.force_retry(self.db.router.drop_user, self.ps.uaid) tags = ['code:105'] @@ -916,9 +917,14 @@ def error_notifications(self, fail): def error_notification_overload(self, fail): """errBack for provisioned errors during notification check""" - fail.trap(ProvisionedThroughputExceededException) + fail.trap(ClientError) + + if (fail.value.response["Error"]["Code"] != + "ProvisionedThroughputExceededException"): + return fail # pragma nocover # Silently ignore the error, and reschedule the notification check - # to run up to a minute in the future to distribute load farther out + # to run up to a minute in the future to distribute load farther + # out d = self.deferToLater(randrange(5, 60), self.process_notifications) d.addErrback(self.trap_cancel) @@ -1068,7 +1074,10 @@ def error_monthly_rotation_overload(self, fail): websocket client flow is returned in the meantime. """ - fail.trap(ProvisionedThroughputExceededException) + fail.trap(ClientError) + if (fail.value.response['Error']['Code'] != + "ProvisionedThroughputExceededException"): + return fail # pragma nocover self.transport.resumeProducing() d = self.deferToLater(randrange(1, 30*60), self.process_notifications) d.addErrback(self.trap_cancel) diff --git a/tox.ini b/tox.ini index 629eaf6c..5ab422e5 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ envlist = py27,pypy,flake8,py36-mypy [testenv] deps = -rtest-requirements.txt usedevelop = True -passenv = SKIP_INTEGRATION +passenv = SKIP_INTEGRATION AWS_SHARED_CREDENTIALS_FILE BOTO_CONFIG commands = nosetests {posargs} autopush install_command = pip install --pre {opts} {packages}