From f8a1c1cd63811e44066cdc6356a59cd736e9039d Mon Sep 17 00:00:00 2001 From: jrconlin Date: Tue, 16 Jan 2018 13:21:30 -0800 Subject: [PATCH] bug: make boto3 calls thread safe Closes #1081 --- autopush/config.py | 3 + autopush/constants.py | 4 + autopush/db.py | 343 ++++++++++++++----------- autopush/diagnostic_cli.py | 16 +- autopush/main.py | 69 ++--- autopush/main_argparse.py | 4 + autopush/metrics.py | 18 -- autopush/router/webpush.py | 2 +- autopush/scripts/drop_user.py | 1 - autopush/tests/__init__.py | 35 ++- autopush/tests/support.py | 4 +- autopush/tests/test_db.py | 347 ++++++++++++++++---------- autopush/tests/test_diagnostic_cli.py | 36 ++- autopush/tests/test_endpoint.py | 4 +- autopush/tests/test_health.py | 80 +----- autopush/tests/test_integration.py | 100 +++++--- autopush/tests/test_main.py | 71 +++--- autopush/tests/test_metrics.py | 19 +- autopush/tests/test_router.py | 22 +- autopush/tests/test_web_base.py | 6 +- autopush/tests/test_web_validation.py | 4 +- autopush/tests/test_web_webpush.py | 12 +- autopush/tests/test_webpush_server.py | 24 +- autopush/tests/test_websocket.py | 165 ++++++------ autopush/web/health.py | 15 +- autopush/web/registration.py | 5 +- autopush/web/webpush.py | 15 +- autopush/webpush_server.py | 39 +-- autopush/websocket.py | 65 ++--- autopush_rs/Cargo.lock | 36 +-- docs/install.rst | 9 +- 31 files changed, 873 insertions(+), 700 deletions(-) create mode 100644 autopush/constants.py diff --git a/autopush/config.py b/autopush/config.py index 992523df..35a75038 100644 --- a/autopush/config.py +++ b/autopush/config.py @@ -176,6 +176,9 @@ class AutopushConfig(object): # Don't cache ssl.wrap_socket's SSLContexts no_sslcontext_cache = attrib(default=False) # type: bool + # DynamoDB endpoint override + aws_ddb_endpoint = attrib(default=None) # type: str + def __attrs_post_init__(self): """Initialize the Settings object""" # Setup hosts/ports/urls diff --git a/autopush/constants.py b/autopush/constants.py new file mode 100644 index 00000000..e8bc31fd --- /dev/null +++ b/autopush/constants.py @@ -0,0 +1,4 @@ +"""Shared constants which might produce circular includes""" + +# Number of concurrent threads / AWS Session resources +THREAD_POOL_SIZE = 50 diff --git a/autopush/db.py b/autopush/db.py index c4cbfba9..7cb80afc 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -34,6 +34,7 @@ import datetime import os import random +import threading import time import uuid from functools import wraps @@ -42,14 +43,16 @@ asdict, attrs, attrib, - Factory -) + Factory, + Attribute) from boto.dynamodb2.exceptions import ( ItemNotFound, ) import boto3 import botocore +from boto.dynamodb2.table import Table # noqa +from boto3.resources.base import ServiceResource # noqa from boto3.dynamodb.conditions import Key from boto3.exceptions import Boto3Error from botocore.exceptions import ClientError @@ -73,6 +76,7 @@ from twisted.internet.threads import deferToThread import autopush.metrics +from autopush import constants from autopush.exceptions import AutopushException from autopush.metrics import IMetrics # noqa from autopush.types import ItemLike # noqa @@ -96,16 +100,7 @@ TRACK_DB_CALLS = False DB_CALLS = [] -# See https://botocore.readthedocs.io/en/stable/reference/config.html for -# additional config options -g_dynamodb = boto3.resource( - 'dynamodb', - config=botocore.config.Config( - region_name=os.getenv("AWS_REGION_NAME", "us-east-1") - ), - endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB") -) -g_client = g_dynamodb.meta.client +MAX_DDB_SESSIONS = constants.THREAD_POOL_SIZE def get_month(delta=0): @@ -145,22 +140,25 @@ def make_rotating_tablename(prefix, delta=0, date=None): return "{}_{:04d}_{:02d}".format(prefix, date.year, date.month) -def create_rotating_message_table(prefix="message", delta=0, date=None, - read_throughput=5, - write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table +def create_rotating_message_table( + prefix="message", # type: str + delta=0, # type: int + date=None, # type: Optional[datetime.date]] + read_throughput=5, # type: int + write_throughput=5, # type: int + boto_resource=None): # type: DynamoDBResource + # type: (...) -> Table # noqa """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta, date) try: - table = g_dynamodb.Table(tablename) - if table.table_status == 'ACTIVE': # pragma nocover + table = boto_resource.Table(tablename) + if table.table_status == 'ACTIVE': return table except ClientError as ex: if ex.response['Error']['Code'] != 'ResourceNotFoundException': - # If we hit this, our boto3 is misconfigured and we need to bail. - raise ex # pragma nocover - table = g_dynamodb.create_table( + pass # pragma nocover + table = boto_resource.create_table( TableName=tablename, KeySchema=[ { @@ -201,25 +199,32 @@ def create_rotating_message_table(prefix="message", delta=0, date=None, return table -def get_rotating_message_table(prefix="message", delta=0, date=None, - message_read_throughput=5, - message_write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table +def get_rotating_message_tablename( + prefix="message", # type: str + delta=0, # type: int + date=None, # type: Optional[datetime.date] + message_read_throughput=5, # type: int + message_write_throughput=5, # type: int + boto_resource=None): # type: DynamoDBResource + # type: (...) -> str # noqa """Gets the message table for the current month.""" tablename = make_rotating_tablename(prefix, delta, date) - if not table_exists(tablename): - return create_rotating_message_table( + if not table_exists(tablename, boto_resource=boto_resource): + create_rotating_message_table( prefix=prefix, delta=delta, date=date, read_throughput=message_read_throughput, write_throughput=message_write_throughput, + boto_resource=boto_resource ) + return tablename else: - return g_dynamodb.Table(tablename) + return tablename def create_router_table(tablename="router", read_throughput=5, - write_throughput=5): - # type: (str, int, int) -> Table + write_throughput=5, + boto_resource=None): + # type: (str, int, int, DynamoDBResource) -> Table """Create a new router table The last_connect index is a value used to determine the last month a user @@ -233,8 +238,7 @@ def create_router_table(tablename="router", read_throughput=5, cost of additional queries during GC to locate expired users. """ - - table = g_dynamodb.create_table( + table = boto_resource.create_table( TableName=tablename, KeySchema=[ { @@ -293,20 +297,30 @@ def create_router_table(tablename="router", read_throughput=5, return table -def _drop_table(tablename): +def _drop_table(tablename, boto_resource): + # type: (str, DynamoDBResource) -> None try: - g_client.delete_table(TableName=tablename) + boto_resource.meta.client.delete_table(TableName=tablename) except ClientError: # pragma nocover pass -def _make_table(table_func, tablename, read_throughput, write_throughput): - # type: (Callable[[str, int, int], Table], str, int, int) -> Table +def _make_table( + table_func, # type: Callable[[str, int, int, ServiceResource]] + tablename, # type: str + read_throughput, # type: int + write_throughput, # type: int + boto_resource # type: DynamoDBResource + ): + # type (...) -> DynamoDBTable """Private common function to make a table with a table func""" - if not table_exists(tablename): - return table_func(tablename, read_throughput, write_throughput) + if not boto_resource: + raise AutopushException("No boto3 resource provided for _make_table") + if not table_exists(tablename, boto_resource): + return table_func(tablename, read_throughput, write_throughput, + boto_resource) else: - return g_dynamodb.Table(tablename) + return DynamoDBTable(boto_resource, tablename) def _expiry(ttl): @@ -314,8 +328,8 @@ def _expiry(ttl): def get_router_table(tablename="router", read_throughput=5, - write_throughput=5): - # type: (str, int, int) -> Table + write_throughput=5, boto_resource=None): + # type: (str, int, int, DynamoDBResource) -> Table """Get the main router table object Creates the table if it doesn't already exist, otherwise returns the @@ -323,7 +337,7 @@ def get_router_table(tablename="router", read_throughput=5, """ return _make_table(create_router_table, tablename, read_throughput, - write_throughput) + write_throughput, boto_resource=boto_resource) def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): @@ -355,7 +369,6 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): message_id=message_id, ttl=60, ) - # Store a notification, fetch it, delete it message.store_message(notif) assert message.delete_message(notif) @@ -432,45 +445,70 @@ def generate_last_connect_values(date): yield int(val) -def list_tables(client=g_client): - """Return a list of the names of all DynamoDB tables.""" - start_table = None - while True: - if start_table: # pragma nocover - result = client.list_tables(ExclusiveStartTableName=start_table) - else: - result = client.list_tables() - for table in result.get('TableNames', []): - yield table - start_table = result.get('LastEvaluatedTableName', None) - if not start_table: - break +def table_exists(tablename, boto_resource=None): + # type: (str, DynamoDBResource) -> bool + """Determine if the specified Table exists""" + try: + return boto_resource.Table(tablename).table_status in [ + 'CREATING', 'UPDATING', 'ACTIVE'] + except ClientError: + return False -def table_exists(tablename, client=None): - """Determine if the specified Table exists""" - if not client: - client = g_client - return tablename in list_tables(client) +class DynamoDBResource(threading.local): + def __init__(self, **kwargs): + conf = kwargs + if not conf.get("endpoint_url"): + conf["endpoint_url"] = os.getenv("AWS_LOCAL_DYNAMODB") + # If there is no endpoint URL, we must delete the entry + if "endpoint_url" in conf and not conf["endpoint_url"]: + del(conf["endpoint_url"]) + region = conf.get("region_name", + os.getenv("AWS_REGION_NAME", "us-east-1")) + if "region_name" in conf: + del(conf["region_name"]) + self.conf = conf + self._resource = boto3.resource( + "dynamodb", + config=botocore.config.Config(region_name=region), + **self.conf) + + def __getattr__(self, name): + return getattr(self._resource, name) + + +class DynamoDBTable(threading.local): + def __init__(self, ddb_resource, *args, **kwargs): + # type: (DynamoDBResource, *Any, **Any) -> None + self._table = ddb_resource.Table(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self._table, name) class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics) -> None + def __init__(self, tablename, metrics=None, boto_resource=None, + max_ttl=MAX_EXPIRY): + # type: (str, IMetrics, DynamoDBResource, int) -> None """Create a new Message object - :param table: :class:`Table` object. - :param metrics: Metrics object that implements the - :class:`autopush.metrics.IMetrics` interface. + :param tablename: name of the table. + :param metrics: unused + :param boto_resource: DynamoDBResource for thread """ - self.table = table - self.metrics = metrics + self.tablename = tablename self._max_ttl = max_ttl + self.resource = boto_resource + + def table(self, tablename=None): + if not tablename: + tablename = self.tablename + return DynamoDBTable(self.resource, tablename) def table_status(self): - return self.table.table_status + return self.table().table_status @track_provisioned def register_channel(self, uaid, channel_id, ttl=None): @@ -483,7 +521,7 @@ def register_channel(self, uaid, channel_id, ttl=None): ":channel_id": set([normalize_id(channel_id)]), ":expiry": _expiry(ttl) } - self.table.update_item( + self.table().update_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -501,7 +539,7 @@ def unregister_channel(self, uaid, channel_id, **kwargs): chid = normalize_id(channel_id) expr_values = {":channel_id": set([chid])} - response = self.table.update_item( + response = self.table().update_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -521,13 +559,13 @@ def unregister_channel(self, uaid, channel_id, **kwargs): @track_provisioned def all_channels(self, uaid): - # type: (str) -> Tuple[bool, Set[str]] + # type: (str, str) -> Tuple[bool, Set[str]] """Retrieve a list of all channels for a given uaid""" # Note: This only returns the chids associated with the UAID. # Functions that call store_message() would be required to # update that list as well using register_channel() - result = self.table.get_item( + result = self.table().get_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -544,7 +582,7 @@ def all_channels(self, uaid): def save_channels(self, uaid, channels): # type: (str, Set[str]) -> None """Save out a set of channels""" - self.table.put_item( + self.table().put_item( Item={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -570,7 +608,7 @@ def store_message(self, notification): ) if notification.data: item['data'] = notification.data - self.table.put_item(Item=item) + self.table().put_item(Item=item) @track_provisioned def delete_message(self, notification): @@ -578,7 +616,7 @@ def delete_message(self, notification): """Deletes a specific message""" if notification.update_id: try: - self.table.delete_item( + self.table().delete_item( Key={ 'uaid': hasher(notification.uaid.hex), 'chidmessageid': notification.sort_key @@ -592,7 +630,7 @@ def delete_message(self, notification): except ClientError: return False else: - self.table.delete_item( + self.table().delete_item( Key={ 'uaid': hasher(notification.uaid.hex), 'chidmessageid': notification.sort_key, @@ -613,16 +651,16 @@ def fetch_messages( """ # Eagerly fetches all results in the result set. - response = self.table.query( + response = self.table().query( KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex)) & Key('chidmessageid').lt('02')), ConsistentRead=True, Limit=limit, ) results = list(response['Items']) - # First extract the position if applicable, slightly higher than 01: - # to ensure we don't load any 01 remainders that didn't get deleted - # yet + # First extract the position if applicable, slightly higher than + # 01: to ensure we don't load any 01 remainders that didn't get + # deleted yet last_position = None if results: # Ensure we return an int, as boto2 can return Decimals @@ -658,7 +696,7 @@ def fetch_timestamp_messages( else: sortkey = "01;" - response = self.table.query( + response = self.table().query( KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex)) & Key('chidmessageid').gt(sortkey)), ConsistentRead=True, @@ -681,7 +719,7 @@ def update_last_message_read(self, uaid, timestamp): expr = "SET current_timestamp=:timestamp, expiry=:expiry" expr_values = {":timestamp": timestamp, ":expiry": _expiry(self._max_ttl)} - self.table.update_item( + self.table().update_item( Key={ "uaid": hasher(uaid.hex), "chidmessageid": " " @@ -694,21 +732,31 @@ def update_last_message_read(self, uaid, timestamp): class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics, int) -> None + def __init__(self, conf, metrics, max_ttl=MAX_EXPIRY, resource=None): + # type: (DDBTableConfig, IMetrics, int, DynamoDBResource) -> None """Create a new Router object :param table: :class:`Table` object. :param metrics: Metrics object that implements the :class:`autopush.metrics.IMetrics` interface. + :param max_ttl: Default maximum time to live. """ - self.table = table + self.conf = conf self.metrics = metrics self._max_ttl = max_ttl + self._cached_table = None + self._resource = resource or DynamoDBResource(**asdict(self.conf)) + self._table = get_router_table( + tablename=self.conf.tablename, + boto_resource=self._resource + ) + + def table(self): + return self._table def table_status(self): - return self.table.table_status + return self.table().table_status def get_uaid(self, uaid): # type: (str) -> Item @@ -721,7 +769,7 @@ def get_uaid(self, uaid): """ try: - item = self.table.get_item( + item = self.table().get_item( Key={ 'uaid': hasher(uaid) }, @@ -762,7 +810,8 @@ def register_user(self, data): db_key = {"uaid": hasher(data["uaid"])} del data["uaid"] if "router_type" not in data or "connected_at" not in data: - # Not specifying these values will generate an exception in AWS. + # Not specifying these values will generate an exception in + # AWS. raise AutopushException("data is missing router_type " "or connected_at") # Generate our update expression @@ -776,7 +825,7 @@ def register_user(self, data): attribute_not_exists(node_id) or (connected_at < :connected_at) )""" - result = self.table.update_item( + result = self.table().update_item( Key=db_key, UpdateExpression=expr, ConditionExpression=cond, @@ -787,7 +836,7 @@ def register_user(self, data): r = {} for key, value in result["Attributes"].items(): try: - r[key] = self.table._dynamizer.decode(value) + r[key] = self.table()._dynamizer.decode(value) except (TypeError, AttributeError): # pragma: nocover # Included for safety as moto has occasionally made # this not work @@ -795,8 +844,8 @@ def register_user(self, data): result = r return (True, result) except ClientError as ex: - # ClientErrors are generated by a factory, and while they have a - # class, it's dynamically generated. + # ClientErrors are generated by a factory, and while they have + # a class, it's dynamically generated. if ex.response['Error']['Code'] == \ 'ConditionalCheckFailedException': return (False, {}) @@ -809,7 +858,7 @@ def drop_user(self, uaid): # The following hack ensures that only uaids that exist and are # deleted return true. try: - item = self.table.get_item( + item = self.table().get_item( Key={ 'uaid': hasher(uaid) }, @@ -817,15 +866,16 @@ def drop_user(self, uaid): ) if 'Item' not in item: return False - except ClientError: + except ClientError: # pragma nocover pass - result = self.table.delete_item(Key={'uaid': hasher(uaid)}) + result = self.table().delete_item( + Key={'uaid': hasher(uaid)}) return result['ResponseMetadata']['HTTPStatusCode'] == 200 def delete_uaids(self, uaids): # type: (List[str]) -> None """Issue a batch delete call for the given uaids""" - with self.table.batch_writer() as batch: + with self.table().batch_writer() as batch: for uaid in uaids: batch.delete_item(Key={'uaid': uaid}) @@ -859,7 +909,7 @@ def drop_old_users(self, months_ago=2): batched = [] for hash_key in generate_last_connect_values(prior_date): - response = self.table.query( + response = self.table().query( KeyConditionExpression=Key("last_connect").eq(hash_key), IndexName="AccessIndex", ) @@ -879,7 +929,7 @@ def drop_old_users(self, months_ago=2): @track_provisioned def _update_last_connect(self, uaid, last_connect): - self.table.update_item( + self.table().update_item( Key={"uaid": hasher(uaid)}, UpdateExpression="SET last_connect=:last_connect", ExpressionAttributeValues={":last_connect": last_connect} @@ -903,7 +953,7 @@ def update_message_month(self, uaid, month): ":last_connect": generate_last_connect(), ":expiry": _expiry(self._max_ttl), } - self.table.update_item( + self.table().update_item( Key=db_key, UpdateExpression=expr, ExpressionAttributeValues=expr_values, @@ -930,7 +980,7 @@ def clear_node(self, item): try: cond = "(node_id = :node) and (connected_at = :conn)" - self.table.put_item( + self.table().put_item( Item=item, ConditionExpression=cond, ExpressionAttributeValues={ @@ -953,15 +1003,15 @@ class DatabaseManager(object): _router_conf = attrib() # type: DDBTableConfig _message_conf = attrib() # type: DDBTableConfig - - metrics = attrib() # type: IMetrics + metrics = attrib() # type: IMetrics + resource = attrib() # type: DynamoDBResource router = attrib(default=None) # type: Optional[Router] - message_tables = attrib(default=Factory(dict)) # type: Dict[str, Message] + message_tables = attrib(default=Factory(list)) # type: List[str] current_msg_month = attrib(init=False) # type: Optional[str] current_month = attrib(init=False) # type: Optional[int] + _message = attrib(default=None) # type: Optional[Message] # for testing: - client = attrib(default=g_client) # type: Optional[Any] def __attrs_post_init__(self): """Initialize sane defaults""" @@ -971,16 +1021,25 @@ def __attrs_post_init__(self): self._message_conf.tablename, date=today ) + if not self.resource: + self.resource = DynamoDBResource() @classmethod - def from_config(cls, conf, **kwargs): - # type: (AutopushConfig, **Any) -> DatabaseManager + def from_config(cls, + conf, # type: AutopushConfig + resource=None, # type: Optional[DynamoDBResource] + **kwargs # type: Any + ): + # type: (...) -> DatabaseManager """Create a DatabaseManager from the given config""" metrics = autopush.metrics.from_config(conf) + if not resource: + resource = DynamoDBResource() return cls( router_conf=conf.router_table, message_conf=conf.message_table, metrics=metrics, + resource=resource, **kwargs ) @@ -994,9 +1053,10 @@ def setup(self, preflight_uaid): def setup_tables(self): """Lookup or create the database tables""" self.router = Router( - get_router_table(**asdict(self._router_conf)), - self.metrics - ) + conf=self._router_conf, + metrics=self.metrics, + resource=self.resource) + self.router.table() # Used to determine whether a connection is out of date with current # db objects. There are three noteworty cases: # 1 "Last Month" the table requires a rollover. @@ -1006,18 +1066,20 @@ def setup_tables(self): # table is present before the switchover is the main reason for this, # just in case some nodes do switch sooner. self.create_initial_message_tables() + self._message = Message(self.current_msg_month, + self.metrics, + boto_resource=self.resource) @property def message(self): # type: () -> Message """Property that access the current message table""" - return self.message_tables[self.current_msg_month] + if not self._message or isinstance(self._message, Attribute): + self._message = self.message_table(self.current_msg_month) + return self._message - @message.setter - def message(self, value): - # type: (Message) -> None - """Setter to set the current message table""" - self.message_tables[self.current_msg_month] = value + def message_table(self, tablename): + return Message(tablename, self.metrics, boto_resource=self.resource) def _tomorrow(self): # type: () -> datetime.date @@ -1032,32 +1094,31 @@ def create_initial_message_tables(self): """ mconf = self._message_conf today = datetime.date.today() - last_month = get_rotating_message_table( + last_month = get_rotating_message_tablename( prefix=mconf.tablename, delta=-1, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_resource=self.resource, ) - this_month = get_rotating_message_table( + this_month = get_rotating_message_tablename( prefix=mconf.tablename, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_resource=self.resource, ) self.current_month = today.month - self.current_msg_month = this_month.table_name - self.message_tables = { - last_month.table_name: Message(last_month, self.metrics), - this_month.table_name: Message(this_month, self.metrics) - } + self.current_msg_month = this_month + self.message_tables = [last_month, this_month] if self._tomorrow().month != today.month: - next_month = get_rotating_message_table( + next_month = get_rotating_message_tablename( prefix=mconf.tablename, delta=1, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_resource=self.resource, ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) + self.message_tables.append(next_month) @inlineCallbacks def update_rotating_tables(self): @@ -1073,34 +1134,32 @@ def update_rotating_tables(self): today = datetime.date.today() tomorrow = self._tomorrow() if ((tomorrow.month != today.month) and - sorted(self.message_tables.keys())[-1] != tomorrow.month): + sorted(self.message_tables)[-1] != tomorrow.month): next_month = yield deferToThread( - get_rotating_message_table, + get_rotating_message_tablename, prefix=mconf.tablename, delta=0, date=tomorrow, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_resource=self.resource ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) - + self.message_tables.append(next_month) if today.month == self.current_month: # No change in month, we're fine. returnValue(False) - # Get tables for the new month, and verify they exist before we try to - # switch over + # Get tables for the new month, and verify they exist before we + # try to switch over message_table = yield deferToThread( - get_rotating_message_table, + get_rotating_message_tablename, prefix=mconf.tablename, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_resource=self.resource, ) # Both tables found, safe to switch-over self.current_month = today.month - self.current_msg_month = message_table.table_name - self.message_tables[self.current_msg_month] = Message( - message_table, self.metrics) + self.current_msg_month = message_table returnValue(True) diff --git a/autopush/diagnostic_cli.py b/autopush/diagnostic_cli.py index 229aae7f..24debb1f 100644 --- a/autopush/diagnostic_cli.py +++ b/autopush/diagnostic_cli.py @@ -7,7 +7,7 @@ from twisted.logger import Logger from autopush.config import AutopushConfig -from autopush.db import DatabaseManager +from autopush.db import DatabaseManager, Message from autopush.main import AutopushMultiService from autopush.main_argparse import add_shared_args @@ -18,11 +18,11 @@ class EndpointDiagnosticCLI(object): log = Logger() - def __init__(self, sysargs, use_files=True): + def __init__(self, sysargs, resource, use_files=True): ns = self._load_args(sysargs, use_files) self._conf = conf = AutopushConfig.from_argparse(ns) conf.statsd_host = None - self.db = DatabaseManager.from_config(conf) + self.db = DatabaseManager.from_config(conf, resource=resource) self.db.setup(conf.preflight_uaid) self._endpoint = ns.endpoint self._pp = pprint.PrettyPrinter(indent=4) @@ -69,12 +69,14 @@ def run(self): print("\n") if "current_month" in rec: - mess_table = rec["current_month"] - chans = self.db.message_tables[mess_table].all_channels(uaid) + chans = Message(rec["current_month"], + boto_resource=self.db.resource).all_channels(uaid) print("Channels in message table:") self._pp.pprint(chans) -def run_endpoint_diagnostic_cli(sysargs=None, use_files=True): - cli = EndpointDiagnosticCLI(sysargs, use_files) +def run_endpoint_diagnostic_cli(sysargs=None, use_files=True, resource=None): + cli = EndpointDiagnosticCLI(sysargs, + resource=resource, + use_files=use_files) return cli.run() diff --git a/autopush/main.py b/autopush/main.py index ca5afb4e..741cf76c 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -20,6 +20,7 @@ Sequence, ) +from autopush import constants from autopush.http import ( InternalRouterHTTPFactory, EndpointHTTPFactory, @@ -29,12 +30,11 @@ import autopush.utils as utils import autopush.logging as logging from autopush.config import AutopushConfig -from autopush.db import DatabaseManager +from autopush.db import DatabaseManager, DynamoDBResource # noqa from autopush.exceptions import InvalidConfig from autopush.haproxy import HAProxyServerEndpoint from autopush.logging import PushLogger from autopush.main_argparse import parse_connection, parse_endpoint -from autopush.metrics import periodic_reporter from autopush.router import routers_from_config from autopush.ssl import ( monkey_patch_ssl_wrap_socket, @@ -62,13 +62,11 @@ class AutopushMultiService(MultiService): config_files = None # type: Sequence[str] logger_name = None # type: str - THREAD_POOL_SIZE = 50 - - def __init__(self, conf): - # type: (AutopushConfig) -> None + def __init__(self, conf, resource=None): + # type: (AutopushConfig, DynamoDBResource) -> None super(AutopushMultiService, self).__init__() self.conf = conf - self.db = DatabaseManager.from_config(conf) + self.db = DatabaseManager.from_config(conf, resource=resource) self.agent = agent_from_config(conf) @staticmethod @@ -104,7 +102,7 @@ def add_memusage(self): def run(self): """Start the services and run the reactor""" - reactor.suggestThreadPoolSize(self.THREAD_POOL_SIZE) + reactor.suggestThreadPoolSize(constants.THREAD_POOL_SIZE) self.startService() reactor.run() @@ -116,8 +114,8 @@ def stopService(self): undo_monkey_patch_ssl_wrap_socket() @classmethod - def _from_argparse(cls, ns, **kwargs): - # type: (Namespace, **Any) -> AutopushMultiService + def _from_argparse(cls, ns, resource=None, **kwargs): + # type: (Namespace, DynamoDBResource, **Any) -> AutopushMultiService """Create an instance from argparse/additional kwargs""" # Add some entropy to prevent potential conflicts. postfix = os.urandom(4).encode('hex').ljust(8, '0') @@ -127,11 +125,11 @@ def _from_argparse(cls, ns, **kwargs): preflight_uaid="deadbeef00000000deadbeef" + postfix, **kwargs ) - return cls(conf) + return cls(conf, resource=resource) @classmethod - def main(cls, args=None, use_files=True): - # type: (Sequence[str], bool) -> Any + def main(cls, args=None, use_files=True, resource=None): + # type: (Sequence[str], bool, DynamoDBResource) -> Any """Entry point to autopush's main command line scripts. aka autopush/autoendpoint. @@ -149,7 +147,8 @@ def main(cls, args=None, use_files=True): firehose_delivery_stream=ns.firehose_stream_name ) try: - app = cls.from_argparse(ns) + cls.argparse = cls.from_argparse(ns, resource=resource) + app = cls.argparse except InvalidConfig as e: log.critical(str(e)) return 1 @@ -173,9 +172,9 @@ class EndpointApplication(AutopushMultiService): endpoint_factory = EndpointHTTPFactory - def __init__(self, conf): - # type: (AutopushConfig) -> None - super(EndpointApplication, self).__init__(conf) + def __init__(self, conf, resource=None): + # type: (AutopushConfig, DynamoDBResource) -> None + super(EndpointApplication, self).__init__(conf, resource=resource) self.routers = routers_from_config(conf, self.db, self.agent) def setup(self, rotate_tables=True): @@ -190,8 +189,6 @@ def setup(self, rotate_tables=True): # Start the table rotation checker/updater if rotate_tables: self.add_timer(60, self.db.update_rotating_tables) - self.add_timer(15, periodic_reporter, self.db.metrics, - prefix='autoendpoint') def add_endpoint(self): """Start the Endpoint HTTP router""" @@ -212,8 +209,8 @@ def add_endpoint(self): self.addService(StreamServerEndpointService(ep, factory)) @classmethod - def from_argparse(cls, ns): - # type: (Namespace) -> AutopushMultiService + def from_argparse(cls, ns, resource=None): + # type: (Namespace, DynamoDBResource) -> AutopushMultiService return super(EndpointApplication, cls)._from_argparse( ns, port=ns.port, @@ -223,6 +220,8 @@ def from_argparse(cls, ns): cors=not ns.no_cors, bear_hash_key=ns.auth_key, proxy_protocol_port=ns.proxy_protocol_port, + aws_ddb_endpoint=ns.aws_ddb_endpoint, + resource=resource ) @@ -243,9 +242,12 @@ class ConnectionApplication(AutopushMultiService): websocket_factory = PushServerFactory websocket_site_factory = ConnectionWSSite - def __init__(self, conf): - # type: (AutopushConfig) -> None - super(ConnectionApplication, self).__init__(conf) + def __init__(self, conf, resource=None): + # type: (AutopushConfig, DynamoDBResource) -> None + super(ConnectionApplication, self).__init__( + conf, + resource=resource + ) self.clients = {} # type: Dict[str, PushServerProtocol] def setup(self, rotate_tables=True): @@ -262,7 +264,6 @@ def setup(self, rotate_tables=True): # Start the table rotation checker/updater if rotate_tables: self.add_timer(60, self.db.update_rotating_tables) - self.add_timer(15, periodic_reporter, self.db.metrics) def add_internal_router(self): """Start the internal HTTP notification router""" @@ -280,8 +281,8 @@ def add_websocket(self): self.add_maybe_ssl(conf.port, site_factory, site_factory.ssl_cf()) @classmethod - def from_argparse(cls, ns): - # type: (Namespace) -> AutopushMultiService + def from_argparse(cls, ns, resource=None): + # type: (Namespace, DynamoDBResource) -> AutopushMultiService return super(ConnectionApplication, cls)._from_argparse( ns, port=ns.port, @@ -302,6 +303,8 @@ def from_argparse(cls, ns): auto_ping_timeout=ns.auto_ping_timeout, max_connections=ns.max_connections, close_handshake_timeout=ns.close_handshake_timeout, + aws_ddb_endpoint=ns.aws_ddb_endpoint, + resource=resource ) @@ -344,8 +347,8 @@ def stopService(self): yield super(RustConnectionApplication, self).stopService() @classmethod - def from_argparse(cls, ns): - # type: (Namespace) -> AutopushMultiService + def from_argparse(cls, ns, resource=None): + # type: (Namespace, DynamoDBResource) -> AutopushMultiService return super(RustConnectionApplication, cls)._from_argparse( ns, port=ns.port, @@ -367,11 +370,13 @@ def from_argparse(cls, ns): auto_ping_timeout=ns.auto_ping_timeout, max_connections=ns.max_connections, close_handshake_timeout=ns.close_handshake_timeout, + aws_ddb_endpoint=ns.aws_ddb_endpoint, + resource=resource ) @classmethod - def main(cls, args=None, use_files=True): - # type: (Sequence[str], bool) -> Any + def main(cls, args=None, use_files=True, resource=None): + # type: (Sequence[str], bool, DynamoDBResource) -> Any """Entry point to autopush's main command line scripts. aka autopush/autoendpoint. @@ -389,7 +394,7 @@ def main(cls, args=None, use_files=True): firehose_delivery_stream=ns.firehose_stream_name ) try: - app = cls.from_argparse(ns) + app = cls.from_argparse(ns, resource=resource) except InvalidConfig as e: log.critical(str(e)) return 1 diff --git a/autopush/main_argparse.py b/autopush/main_argparse.py index 0813a6c2..1ad06da9 100644 --- a/autopush/main_argparse.py +++ b/autopush/main_argparse.py @@ -101,6 +101,10 @@ def add_shared_args(parser): help="Don't cache ssl.wrap_socket's SSLContexts", action="store_true", default=False, env_var="_NO_SSLCONTEXT_CACHE") + parser.add_argument('--aws_ddb_endpoint', + help="AWS DynamoDB endpoint override", + type=str, default=None, + env_var="AWS_LOCAL_DYNAMODB") # No ENV because this is for humans _add_external_router_args(parser) _obsolete_args(parser) diff --git a/autopush/metrics.py b/autopush/metrics.py index 6becd51f..38766e91 100644 --- a/autopush/metrics.py +++ b/autopush/metrics.py @@ -129,21 +129,3 @@ def from_config(conf): return TwistedMetrics(conf.statsd_host, conf.statsd_port) else: return SinkMetrics() - - -def periodic_reporter(metrics, prefix=''): - # type: (IMetrics, Optional[str]) -> None - """Emit metrics on twisted's thread pool. - - Only meant to be called via a LoopingCall (TimerService). - - """ - # unfortunately stats only available via the private '_team' - stats = reactor.getThreadPool()._team.statistics() - for attr in ('idleWorkerCount', 'busyWorkerCount', 'backloggedWorkCount'): - name = '{}{}twisted.threadpool.{}'.format( - prefix, - '.' if prefix else '', - attr - ) - metrics.gauge(name, getattr(stats, attr)) diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 177c9a08..93602140 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -231,7 +231,7 @@ def _save_notification(self, uaid_data, notification): "Location": location}, logged_status=204) return deferToThread( - self.db.message_tables[month_table].store_message, + self.db.message_table(month_table).store_message, notification=notification, ) diff --git a/autopush/scripts/drop_user.py b/autopush/scripts/drop_user.py index 2405e176..eb5a49c5 100644 --- a/autopush/scripts/drop_user.py +++ b/autopush/scripts/drop_user.py @@ -19,7 +19,6 @@ def drop_users(router_table_name, months_ago, batch_size, pause_time): router_table = get_router_table(router_table_name) router = Router(router_table, SinkMetrics()) - click.echo("Deleting users with a last_connect %s months ago." % months_ago) diff --git a/autopush/tests/__init__.py b/autopush/tests/__init__.py index adc0ead4..b090f187 100644 --- a/autopush/tests/__init__.py +++ b/autopush/tests/__init__.py @@ -4,12 +4,10 @@ import subprocess import boto -import botocore -import boto3 import psutil +from twisted.internet import reactor -import autopush.db -from autopush.db import create_rotating_message_table +from autopush.db import create_rotating_message_table, DynamoDBResource here_dir = os.path.abspath(os.path.dirname(__file__)) root_dir = os.path.dirname(os.path.dirname(here_dir)) @@ -17,35 +15,34 @@ ddb_lib_dir = os.path.join(ddb_dir, "DynamoDBLocal_lib") ddb_jar = os.path.join(ddb_dir, "DynamoDBLocal.jar") ddb_process = None +boto_resource = None def setUp(): for name in ('boto', 'boto3', 'botocore'): logging.getLogger(name).setLevel(logging.CRITICAL) - global ddb_process + global ddb_process, boto_resource cmd = " ".join([ "java", "-Djava.library.path=%s" % ddb_lib_dir, "-jar", ddb_jar, "-sharedDb", "-inMemory" ]) - conf = botocore.config.Config( - region_name=os.getenv('AWS_REGION_NAME', 'us-east-1') - ) ddb_process = subprocess.Popen(cmd, shell=True, env=os.environ) - autopush.db.g_dynamodb = boto3.resource( - 'dynamodb', - config=conf, - endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB", "http://127.0.0.1:8000"), + if os.getenv("AWS_LOCAL_DYNAMODB") is None: + os.environ["AWS_LOCAL_DYNAMODB"] = "http://127.0.0.1:8000" + ddb_session_args = dict( + endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB"), aws_access_key_id="BogusKey", aws_secret_access_key="BogusKey", ) - - autopush.db.g_client = autopush.db.g_dynamodb.meta.client - + boto_resource = DynamoDBResource(**ddb_session_args) # Setup the necessary message tables message_table = os.environ.get("MESSAGE_TABLE", "message_int_test") - - create_rotating_message_table(prefix=message_table, delta=-1) - create_rotating_message_table(prefix=message_table) + create_rotating_message_table(prefix=message_table, delta=-1, + boto_resource=boto_resource) + create_rotating_message_table(prefix=message_table, + boto_resource=boto_resource) + pool = reactor.getThreadPool() + pool.adjustPoolsize(minthreads=pool.max) def tearDown(): @@ -59,7 +56,7 @@ def tearDown(): # Clear out the boto config that was loaded so the rest of the tests run # fine - for section in boto.config.sections(): + for section in boto.config.sections(): # pragma: nocover boto.config.remove_section(section) diff --git a/autopush/tests/support.py b/autopush/tests/support.py index 1d4e4bfd..93608e97 100644 --- a/autopush/tests/support.py +++ b/autopush/tests/support.py @@ -8,6 +8,7 @@ Router, ) from autopush.metrics import SinkMetrics +import autopush.tests @implementer(ILogObserver) @@ -44,5 +45,6 @@ def test_db(metrics=None): router_conf=DDBTableConfig(tablename='router'), router=Mock(spec=Router), message_conf=DDBTableConfig(tablename='message'), - metrics=SinkMetrics() if metrics is None else metrics + metrics=SinkMetrics() if metrics is None else metrics, + resource=autopush.tests.boto_resource, ) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index 74fd8063..212da1cc 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -1,3 +1,4 @@ +import os import unittest import uuid from datetime import datetime, timedelta @@ -7,12 +8,12 @@ ItemNotFound, ) from botocore.exceptions import ClientError -from mock import Mock +from mock import Mock, patch import pytest +from autopush.config import DDBTableConfig from autopush.db import ( - get_rotating_message_table, - get_router_table, + get_rotating_message_tablename, create_router_table, preflight_check, table_exists, @@ -20,16 +21,31 @@ Router, generate_last_connect, make_rotating_tablename, + create_rotating_message_table, _drop_table, - _make_table) + _make_table, + DatabaseManager, + DynamoDBResource + ) from autopush.exceptions import AutopushException from autopush.metrics import SinkMetrics from autopush.utils import WebPushNotification +# nose fails to import sessions correctly. +import autopush.tests dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) +test_router = None + + +def setUp(): + global test_router + config = DDBTableConfig("router_test") + test_router = Router(config, SinkMetrics(), + resource=autopush.tests.boto_resource) + def make_webpush_notification(uaid, chid, ttl=100): message_id = str(uuid.uuid4()) @@ -44,17 +60,60 @@ def make_webpush_notification(uaid, chid, ttl=100): class DbUtilsTest(unittest.TestCase): def test_make_table(self): + fake_resource = Mock() fake_func = Mock() fake_table = "DoesNotExist_{}".format(uuid.uuid4()) - _make_table(fake_func, fake_table, 5, 10) - assert fake_func.call_args[0] == (fake_table, 5, 10) + _make_table(fake_func, fake_table, 5, 10, boto_resource=fake_resource) + assert fake_func.call_args[0] == (fake_table, 5, 10, fake_resource) + + def test_make_table_no_resource(self): + fake_func = Mock() + fake_table = "DoesNotExist_{}".format(uuid.uuid4()) + + with pytest.raises(AutopushException) as ex: + _make_table(fake_func, fake_table, 5, 10, + boto_resource=None) + assert ex.value.message == "No boto3 resource provided for _make_table" + + +class DatabaseManagerTest(unittest.TestCase): + def test_init_with_resources(self): + from autopush.db import DynamoDBResource + dm = DatabaseManager(router_conf=Mock(), + message_conf=Mock(), + metrics=Mock(), + resource=None) + assert dm.resource is not None + assert isinstance(dm.resource, DynamoDBResource) + + +class DdbResourceTest(unittest.TestCase): + @patch("boto3.resource") + def test_ddb_no_endpoint(self, mresource): + safe = os.getenv("AWS_LOCAL_DYNAMODB") + os.unsetenv("AWS_LOCAL_DYANMODB") + del(os.environ["AWS_LOCAL_DYNAMODB"]) + DynamoDBResource(region_name="us-east-1") + assert mresource.call_args[0] == ('dynamodb',) + if safe: + os.environ["AWS_LOCAL_DYNAMODB"] = safe class DbCheckTestCase(unittest.TestCase): + def setUp(cls): + cls.resource = autopush.tests.boto_resource + cls.table_conf = DDBTableConfig("router_test") + cls.router = Router(cls.table_conf, SinkMetrics(), + resource=cls.resource) + def test_preflight_check_fail(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) + message = Message(get_rotating_message_tablename( + boto_resource=self.resource), + SinkMetrics(), + boto_resource=self.resource) def raise_exc(*args, **kwargs): # pragma: no cover raise Exception("Oops") @@ -63,34 +122,38 @@ def raise_exc(*args, **kwargs): # pragma: no cover router.clear_node.side_effect = raise_exc with pytest.raises(Exception): - preflight_check(message, router) + preflight_check(message, router, self.resource) def test_preflight_check(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + message = Message(get_rotating_message_tablename( + boto_resource=self.resource), + SinkMetrics(), + boto_resource=self.resource) pf_uaid = "deadbeef00000000deadbeef01010101" - preflight_check(message, router, pf_uaid) + preflight_check(message, test_router, pf_uaid) # now check that the database reports no entries. _, notifs = message.fetch_messages(uuid.UUID(pf_uaid)) assert len(notifs) == 0 with pytest.raises(ItemNotFound): - router.get_uaid(pf_uaid) + self.router.get_uaid(pf_uaid) def test_preflight_check_wait(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + message = Message(get_rotating_message_tablename( + boto_resource=self.resource), + SinkMetrics(), + boto_resource=self.resource) values = ["PENDING", "ACTIVE"] message.table_status = Mock(side_effect=values) pf_uaid = "deadbeef00000000deadbeef01010101" - preflight_check(message, router, pf_uaid) + preflight_check(message, test_router, pf_uaid) # now check that the database reports no entries. _, notifs = message.fetch_messages(uuid.UUID(pf_uaid)) assert len(notifs) == 0 with pytest.raises(ItemNotFound): - router.get_uaid(pf_uaid) + self.router.get_uaid(pf_uaid) def test_get_month(self): from autopush.db import get_month @@ -126,21 +189,21 @@ def test_normalize_id(self): class MessageTestCase(unittest.TestCase): def setUp(self): - table = get_rotating_message_table() + self.resource = autopush.tests.boto_resource + table = get_rotating_message_tablename(boto_resource=self.resource) self.real_table = table self.uaid = str(uuid.uuid4()) - def tearDown(self): - pass - def test_register(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, metrics=SinkMetrics(), + boto_resource=self.resource) + message.register_channel(self.uaid, chid) + lm = self.resource.Table(m) # Verify it's in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -157,12 +220,15 @@ def test_register(self): def test_unregister(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, metrics=SinkMetrics(), + boto_resource=self.resource) message.register_channel(self.uaid, chid) # Verify its in the db - response = m.query( + lm = self.resource.Table(m) + # Verify it's in the db + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -182,7 +248,7 @@ def test_unregister(self): message.unregister_channel(self.uaid, chid) # Verify its not in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -200,18 +266,20 @@ def test_unregister(self): assert results[0].get("chids") is None # Test for the very unlikely case that there's no 'chid' - m.update_item = Mock(return_value={ + mtable = Mock() + mtable.update_item = Mock(return_value={ 'Attributes': {'uaid': self.uaid}, 'ResponseMetaData': {} }) + message.table = Mock(return_value=mtable) r = message.unregister_channel(self.uaid, dummy_chid) assert r is False def test_all_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), boto_resource=self.resource) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -225,24 +293,25 @@ def test_all_channels(self): assert chid in chans def test_all_channels_fail(self): - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.table.get_item = Mock() - message.table.get_item.return_value = { + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), boto_resource=self.resource) + + mtable = Mock() + mtable.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } - + message.table = Mock(return_value=mtable) res = message.all_channels(self.uaid) assert res == (False, set([])) def test_save_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), boto_resource=self.resource) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -253,16 +322,16 @@ def test_save_channels(self): assert chans == new_chans def test_all_channels_no_uaid(self): - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), boto_resource=self.resource) exists, chans = message.all_channels(dummy_uaid) assert chans == set([]) def test_message_storage(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), boto_resource=self.resource) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -283,8 +352,8 @@ def test_message_storage_overwrite(self): notif2 = make_webpush_notification(self.uaid, chid) notif3 = make_webpush_notification(self.uaid, chid2) notif2.message_id = notif1.message_id - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), boto_resource=self.resource) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -292,20 +361,22 @@ def test_message_storage_overwrite(self): message.store_message(notif2) message.store_message(notif3) - all_messages = list(message.fetch_messages(uuid.UUID(self.uaid))) + all_messages = list(message.fetch_messages( + uuid.UUID(self.uaid))) assert len(all_messages) == 2 def test_message_delete_fail_condition(self): notif = make_webpush_notification(dummy_uaid, dummy_chid) notif.message_id = notif.update_id = dummy_uaid - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), boto_resource=self.resource) def raise_condition(*args, **kwargs): raise ClientError({}, 'delete_item') - message.table = Mock() - message.table.delete_item.side_effect = raise_condition + m_de = Mock() + m_de.delete_item = Mock(side_effect=raise_condition) + message.table = Mock(return_value=m_de) result = message.delete_message(notif) assert result is False @@ -314,22 +385,19 @@ def test_message_rotate_table_with_date(self): future = (datetime.today() + timedelta(days=32)).date() tbl_name = make_rotating_tablename(prefix, date=future) - m = get_rotating_message_table(prefix=prefix, date=future) - assert m.table_name == tbl_name + m = get_rotating_message_tablename(prefix=prefix, date=future, + boto_resource=self.resource) + assert m == tbl_name # Clean up the temp table. - _drop_table(tbl_name) + _drop_table(tbl_name, boto_resource=self.resource) class RouterTestCase(unittest.TestCase): @classmethod - def setup_class(self): - table = get_router_table() - self.real_table = table - self.real_connection = table.meta.client - - @classmethod - def teardown_class(self): - self.real_table.meta.client = self.real_connection + def setUpClass(cls): + cls.resource = autopush.tests.boto_resource + cls.table_conf = DDBTableConfig("router_test") + cls.router = test_router def _create_minimal_record(self): data = { @@ -343,69 +411,76 @@ def _create_minimal_record(self): def test_drop_old_users(self): # First create a bunch of users - r = get_router_table() - router = Router(r, SinkMetrics()) # Purge any existing users from previous runs. - router.drop_old_users(0) + self.router.drop_old_users(months_ago=0) for _ in range(0, 53): - router.register_user(self._create_minimal_record()) + self.router.register_user(self._create_minimal_record()) - results = router.drop_old_users(months_ago=0) + results = self.router.drop_old_users(months_ago=0) assert list(results) == [25, 25, 3] def test_custom_tablename(self): db_name = "router_%s" % uuid.uuid4() - assert not table_exists(db_name) - create_router_table(db_name) - assert table_exists(db_name) + assert not table_exists(db_name, boto_resource=self.resource) + create_router_table(db_name, boto_resource=self.resource) + assert table_exists(db_name, boto_resource=self.resource) # Clean up the temp table. - _drop_table(db_name) + _drop_table(db_name, boto_resource=self.resource) + + def test_create_rotating_cache(self): + mock_table = Mock() + mock_table.table_status = 'ACTIVE' + mock_resource = Mock() + mock_resource.Table = Mock(return_value=mock_table) + table = create_rotating_message_table(boto_resource=mock_resource) + assert table == mock_table def test_provisioning(self): db_name = "router_%s" % uuid.uuid4() - r = create_router_table(db_name, 3, 17) + r = create_router_table(db_name, 3, 17, + boto_resource=self.resource) assert r.provisioned_throughput.get('ReadCapacityUnits') == 3 assert r.provisioned_throughput.get('WriteCapacityUnits') == 17 def test_no_uaid_found(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) with pytest.raises(ItemNotFound): - router.get_uaid(uaid) + self.router.get_uaid(uaid) def test_uaid_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) router.table = Mock() def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.get_item.side_effect = raise_condition + mm = Mock() + mm.get_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: router.get_uaid(uaid="asdf") assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_register_user_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router(self.table_conf, SinkMetrics(), resource=self.resource) + mm = Mock() + mm.client = Mock() + + router.table = Mock(return_value=mm) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + mm.update_item = Mock(side_effect=raise_condition) with pytest.raises(ClientError) as ex: router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, @@ -414,35 +489,36 @@ def raise_condition(*args, **kwargs): "ProvisionedThroughputExceededException") def test_register_user_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) + router.table().meta.client = Mock() def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - - router.table.update_item = Mock(side_effect=raise_error) + mm = Mock() + mm.update_item = Mock(side_effect=raise_error) + router.table = Mock(return_value=mm) res = router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, router_type="webpush")) assert res == (False, {}) def test_clear_node_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + mm = Mock() + mm.put_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", @@ -452,32 +528,35 @@ def raise_condition(*args, **kwargs): "ProvisionedThroughputExceededException") def test_clear_node_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_put_item' ) - router.table.put_item = Mock(side_effect=raise_error) + mock_put = Mock() + mock_put.put_item = Mock(side_effect=raise_error) + mock_table = Mock(return_value=mock_put) + router.table = mock_table res = router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", node_id="asdf", router_type="webpush")) + assert res is False def test_incomplete_uaid(self): # Older records may be incomplete. We can't inject them using normal # methods. uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 200 }, @@ -485,6 +564,13 @@ def test_incomplete_uaid(self): "uaid": uuid.uuid4().hex } } + mm.delete_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 200 + }, + } + router.table = Mock(return_value=mm) + router.drop_user = Mock() try: router.register_user(dict(uaid=uaid)) except AutopushException: @@ -495,24 +581,28 @@ def test_incomplete_uaid(self): def test_failed_uaid(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } + router.table = Mock(return_value=mm) + router.drop_user = Mock() with pytest.raises(ItemNotFound): router.get_uaid(uaid) def test_save_new(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) # Sadly, moto currently does not return an empty value like boto # when not updating data. - router.table.update_item = Mock(return_value={}) + mock_update = Mock() + mock_update.update_item = Mock(return_value={}) + router.table = Mock(return_value=mock_update) result = router.register_user(dict(uaid=dummy_uaid, node_id="me", router_type="webpush", @@ -520,25 +610,26 @@ def test_save_new(self): assert result[0] is True def test_save_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + mock_update = Mock() + mock_update.update_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mock_update) router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, router_type="webpush") result = router.register_user(router_data) assert result == (False, {}) def test_node_clear(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) # Register a node user router.register_user(dict(uaid=dummy_uaid, node_id="asdf", @@ -560,8 +651,8 @@ def test_node_clear(self): assert user["router_type"] == "webpush" def test_node_clear_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) def raise_condition(*args, **kwargs): raise ClientError( @@ -569,15 +660,17 @@ def raise_condition(*args, **kwargs): 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + mock_put = Mock() + mock_put.put_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mock_put) data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) result = router.clear_node(data) assert result is False def test_drop_user(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router(self.table_conf, SinkMetrics(), + resource=self.resource) # Register a node user router.register_user(dict(uaid=uaid, node_id="asdf", router_type="webpush", diff --git a/autopush/tests/test_diagnostic_cli.py b/autopush/tests/test_diagnostic_cli.py index adc7970f..fbfc1a5d 100644 --- a/autopush/tests/test_diagnostic_cli.py +++ b/autopush/tests/test_diagnostic_cli.py @@ -2,6 +2,8 @@ from mock import Mock, patch +import autopush.tests + class FakeDict(dict): pass @@ -10,14 +12,18 @@ class FakeDict(dict): class DiagnosticCLITestCase(unittest.TestCase): def _makeFUT(self, *args, **kwargs): from autopush.diagnostic_cli import EndpointDiagnosticCLI - return EndpointDiagnosticCLI(*args, use_files=False, **kwargs) + return EndpointDiagnosticCLI( + *args, + resource=autopush.tests.boto_resource, + use_files=False, + **kwargs) def test_basic_load(self): cli = self._makeFUT([ "--router_tablename=fred", "http://someendpoint", ]) - assert cli.db.router.table.table_name == "fred" + assert cli.db.router.table().table_name == "fred" def test_bad_endpoint(self): cli = self._makeFUT([ @@ -27,9 +33,10 @@ def test_bad_endpoint(self): returncode = cli.run() assert returncode not in (None, 0) + @patch("autopush.diagnostic_cli.Message") @patch("autopush.diagnostic_cli.AutopushConfig") @patch("autopush.diagnostic_cli.DatabaseManager.from_config") - def test_successfull_lookup(self, mock_db_cstr, mock_conf_class): + def test_successfull_lookup(self, mock_db_cstr, mock_conf_class, mock_msg): from autopush.diagnostic_cli import run_endpoint_diagnostic_cli mock_conf_class.return_value = mock_conf = Mock() mock_conf.parse_endpoint.return_value = dict( @@ -39,18 +46,25 @@ def test_successfull_lookup(self, mock_db_cstr, mock_conf_class): mock_db.router.get_uaid.return_value = mock_item = FakeDict() mock_item._data = {} mock_item["current_month"] = "201608120002" - mock_message_table = Mock() - mock_db.message_tables = {"201608120002": mock_message_table} + mock_db.message_tables = ["2016081200002"] + mock_msg.return_value = mock_message = Mock() - run_endpoint_diagnostic_cli([ - "--router_tablename=fred", - "http://something/wpush/v1/legit_endpoint", - ], use_files=False) - mock_message_table.all_channels.assert_called() + run_endpoint_diagnostic_cli( + sysargs=[ + "--router_tablename=fred", + "http://something/wpush/v1/legit_endpoint", + ], + use_files=False, + resource=autopush.tests.boto_resource) + mock_message.all_channels.assert_called() def test_parser_tuple(self): from autopush.diagnostic_cli import EndpointDiagnosticCLI - edc = EndpointDiagnosticCLI(("http://someendpoint",)) + edc = EndpointDiagnosticCLI( + ("http://someendpoint",), + use_files=False, + resource=autopush.tests.boto_resource + ) assert edc is not None assert edc._endpoint == "http://someendpoint" diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index af7426aa..70757f70 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -14,6 +14,7 @@ Message, ItemNotFound, has_connected_this_month, + Router ) from autopush.exceptions import RouterException from autopush.http import EndpointHTTPFactory @@ -51,7 +52,7 @@ def setUp(self): crypto_key='AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=', ) db = test_db() - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) self.fernet_mock = conf.fernet = Mock(spec=Fernet) app = EndpointHTTPFactory.for_handler(MessageHandler, conf, db=db) @@ -140,6 +141,7 @@ def setUp(self): self.fernet_mock = conf.fernet = Mock(spec=Fernet) self.db = db = test_db() + db.router = Mock(spec=Router) db.router.register_user.return_value = (True, {}, {}) db.router.get_uaid.return_value = { "router_type": "test", diff --git a/autopush/tests/test_health.py b/autopush/tests/test_health.py index 3e96896f..537b9cbe 100644 --- a/autopush/tests/test_health.py +++ b/autopush/tests/test_health.py @@ -1,15 +1,13 @@ import json import twisted.internet.base -from boto.dynamodb2.exceptions import InternalServerError from mock import Mock from twisted.internet.defer import inlineCallbacks from twisted.logger import globalLogPublisher from twisted.trial import unittest -import autopush.db from autopush import __version__ -from autopush.config import AutopushConfig +from autopush.config import AutopushConfig, DDBTableConfig from autopush.db import DatabaseManager from autopush.exceptions import MissingTableException from autopush.http import EndpointHTTPFactory @@ -17,20 +15,24 @@ from autopush.tests.client import Client from autopush.tests.support import TestingLogObserver from autopush.web.health import HealthHandler, StatusHandler +import autopush.tests class HealthTestCase(unittest.TestCase): def setUp(self): - self.timeout = 0.5 + self.timeout = 4 twisted.internet.base.DelayedCall.debug = True conf = AutopushConfig( hostname="localhost", statsd_host=None, + router_table=DDBTableConfig(tablename="router_test"), + message_table=DDBTableConfig(tablename="message_int_test"), ) - db = DatabaseManager.from_config(conf) - db.client = autopush.db.g_client + db = DatabaseManager.from_config( + conf, + resource=autopush.tests.boto_resource) db.setup_tables() # ignore logging @@ -39,7 +41,6 @@ def setUp(self): self.addCleanup(globalLogPublisher.removeObserver, logs) app = EndpointHTTPFactory.for_handler(HealthHandler, conf, db=db) - self.router_table = app.db.router.table self.message = app.db.message self.client = Client(app) @@ -50,82 +51,23 @@ def test_healthy(self): "version": __version__, "clients": 0, "storage": {"status": "OK"}, - "router": {"status": "OK"} + "router_test": {"status": "OK"} }) - @inlineCallbacks - def test_aws_error(self): - - def raise_error(*args, **kwargs): - raise InternalServerError(None, None) - - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = Mock(side_effect=raise_error) - - yield self._assert_reply({ - "status": "NOT OK", - "version": __version__, - "clients": 0, - "storage": { - "status": "NOT OK", - "error": "Server error" - }, - "router": { - "status": "NOT OK", - "error": "Server error" - } - }, InternalServerError) - - self.client.app.db.client = safe - @inlineCallbacks def test_nonexistent_table(self): - no_tables = Mock(return_value={"TableNames": []}) - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = no_tables + self.client.app.db.message.table().delete() yield self._assert_reply({ "status": "NOT OK", "version": __version__, "clients": 0, + "router_test": {"status": "OK"}, "storage": { "status": "NOT OK", "error": "Nonexistent table" - }, - "router": { - "status": "NOT OK", - "error": "Nonexistent table" } }, MissingTableException) - self.client.app.db.client = safe - - @inlineCallbacks - def test_internal_error(self): - def raise_error(*args, **kwargs): - raise Exception("synergies not aligned") - - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = Mock( - side_effect=raise_error - ) - - yield self._assert_reply({ - "status": "NOT OK", - "version": __version__, - "clients": 0, - "storage": { - "status": "NOT OK", - "error": "Internal error" - }, - "router": { - "status": "NOT OK", - "error": "Internal error" - } - }, Exception) - self.client.app.db.client = safe @inlineCallbacks def _assert_reply(self, reply, exception=None): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index 69dd0581..7e767515 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -36,12 +36,14 @@ from autopush.config import AutopushConfig from autopush.db import ( get_month, - has_connected_this_month + has_connected_this_month, + Message, ) from autopush.logging import begin_or_register from autopush.main import ConnectionApplication, EndpointApplication from autopush.utils import base64url_encode, normalize_id from autopush.metrics import SinkMetrics, DatadogMetrics +import autopush.tests from autopush.tests.support import TestingLogObserver from autopush.websocket import PushServerFactory @@ -130,7 +132,8 @@ def register(self, chid=None, key=None): key=key)) log.debug("Send: %s", msg) self.ws.send(msg) - result = json.loads(self.ws.recv()) + rcv = self.ws.recv() + result = json.loads(rcv) log.debug("Recv: %s", result) assert result["status"] == 200 assert result["channelID"] == chid @@ -319,15 +322,19 @@ def setUp(self): ) # Endpoint HTTP router - self.ep = ep = EndpointApplication(ep_conf) - ep.db.client = db.g_client + self.ep = ep = EndpointApplication( + conf=ep_conf, + resource=autopush.tests.boto_resource + ) ep.setup(rotate_tables=False) ep.startService() self.addCleanup(ep.stopService) # Websocket server - self.conn = conn = ConnectionApplication(conn_conf) - conn.db.client = db.g_client + self.conn = conn = ConnectionApplication( + conf=conn_conf, + resource=autopush.tests.boto_resource + ) conn.setup(rotate_tables=False) conn.startService() self.addCleanup(conn.stopService) @@ -472,10 +479,10 @@ def test_legacy_simplepush_record(self): return_value={'router_type': 'simplepush', 'uaid': uaid, 'current_month': self.ep.db.current_msg_month}) - self.ep.db.message_tables[ - self.ep.db.current_msg_month].all_channels = Mock( - return_value=(True, client.channels)) + safe = db.Message.all_channels + db.Message.all_channels = Mock(return_value=(True, client.channels)) yield client.send_notification() + db.Message.all_channels = safe yield self.shut_down(client) @patch("autopush.metrics.datadog") @@ -537,12 +544,13 @@ def test_webpush_data_save_fail(self): yield client.hello() yield client.register(chid=chan) yield client.disconnect() - self.ep.db.message_tables[ - self.ep.db.current_msg_month].store_message = Mock( + safe = db.Message.store_message + db.Message.store_message = Mock( return_value=False) yield client.send_notification(channel=chan, data=test["data"], status=201) + db.Message.store_message = safe yield self.shut_down(client) @@ -1157,15 +1165,16 @@ def test_webpush_monthly_rotation(self): # Move the client back one month to the past last_month = make_rotating_tablename( prefix=self.conn.conf.message_table.tablename, delta=-1) - lm_message = self.conn.db.message_tables[last_month] + lm_message = Message(last_month, boto_resource=self.conn.db.resource) yield deferToThread( self.conn.db.router.update_message_month, client.uaid, - last_month + last_month, ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == last_month # Verify last_connect is current, then move that back @@ -1177,33 +1186,34 @@ def test_webpush_monthly_rotation(self): yield deferToThread( self.conn.db.router._update_last_connect, client.uaid, - last_connect, - ) - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + last_connect) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.db.message.all_channels, client.uaid + self.conn.db.message.all_channels, + client.uaid ) assert exists is True assert len(chans) == 1 yield deferToThread( lm_message.save_channels, client.uaid, - chans + chans, ) # Remove the channels entry entirely from this month yield deferToThread( - self.conn.db.message.table.delete_item, + self.conn.db.message.table().delete_item, Key={'uaid': client.uaid, 'chidmessageid': ' '} ) # Verify the channel is gone exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, ) assert exists is False assert len(chans) == 0 @@ -1235,11 +1245,12 @@ def test_webpush_monthly_rotation(self): # Acknowledge the notification, which triggers the migration yield client.ack(chan, result["version"]) - # Wait up to 2 seconds for the table rotation to occur + # Wait up to 4 seconds for the table rotation to occur start = time.time() - while time.time()-start < 2: + while time.time()-start < 4: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid) if c["current_month"] == self.conn.db.current_msg_month: break else: @@ -1247,7 +1258,8 @@ def test_webpush_monthly_rotation(self): # Verify the month update in the router table c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == self.conn.db.current_msg_month assert server_client.ps.rotate_message_table is False @@ -1261,7 +1273,6 @@ def test_webpush_monthly_rotation(self): ) assert exists is True assert len(chans) == 1 - yield self.shut_down(client) @inlineCallbacks @@ -1273,15 +1284,17 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Move the client back one month to the past last_month = make_rotating_tablename( prefix=self.conn.conf.message_table.tablename, delta=-1) - lm_message = self.conn.db.message_tables[last_month] + lm_message = Message(last_month, + boto_resource=autopush.tests.boto_resource) yield deferToThread( self.conn.db.router.update_message_month, client.uaid, - last_month + last_month, ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == last_month # Verify last_connect is current, then move that back @@ -1290,21 +1303,22 @@ def test_webpush_monthly_rotation_prior_record_exists(self): yield deferToThread( self.conn.db.router._update_last_connect, client.uaid, - int("%s%s020001" % (today.year, str(today.month).zfill(2))) + int("%s%s020001" % (today.year, str(today.month).zfill(2))), ) c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.db.message.all_channels, client.uaid + self.conn.db.message.all_channels, + client.uaid, ) assert exists is True assert len(chans) == 1 yield deferToThread( lm_message.save_channels, client.uaid, - chans + chans, ) # Send in a notification, verify it landed in last months notification @@ -1334,11 +1348,12 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Acknowledge the notification, which triggers the migration yield client.ack(chan, result["version"]) - # Wait up to 2 seconds for the table rotation to occur + # Wait up to 4 seconds for the table rotation to occur start = time.time() - while time.time()-start < 2: + while time.time()-start < 4: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid) if c["current_month"] == self.conn.db.current_msg_month: break else: @@ -1359,7 +1374,6 @@ def test_webpush_monthly_rotation_prior_record_exists(self): ) assert exists is True assert len(chans) == 1 - yield self.shut_down(client) @inlineCallbacks @@ -1380,13 +1394,15 @@ def test_webpush_monthly_rotation_no_channels(self): ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid + ) assert c["current_month"] == last_month # Verify there's no channels exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, ) assert exists is False assert len(chans) == 0 @@ -1403,17 +1419,19 @@ def test_webpush_monthly_rotation_no_channels(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid, + ) if c["current_month"] == self.conn.db.current_msg_month: break else: yield deferToThread(time.sleep, 0.2) # Verify the month update in the router table - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == self.conn.db.current_msg_month assert server_client.ps.rotate_message_table is False - yield self.shut_down(client) @inlineCallbacks diff --git a/autopush/tests/test_main.py b/autopush/tests/test_main.py index 4b95aa07..d152095f 100644 --- a/autopush/tests/test_main.py +++ b/autopush/tests/test_main.py @@ -11,7 +11,11 @@ import autopush.db from autopush.config import AutopushConfig -from autopush.db import DatabaseManager, get_rotating_message_table +from autopush.db import ( + DatabaseManager, + get_rotating_message_tablename, + make_rotating_tablename, +) from autopush.exceptions import InvalidConfig from autopush.http import skip_request_logging from autopush.main import ( @@ -20,6 +24,7 @@ ) from autopush.tests.support import test_db from autopush.utils import resolve_ip +import autopush.tests connection_main = ConnectionApplication.main endpoint_main = EndpointApplication.main @@ -60,27 +65,30 @@ def test_update_rotating_tables(self): from autopush.db import get_month conf = AutopushConfig( hostname="example.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource=autopush.tests.boto_resource) db.create_initial_message_tables() # Erase the tables it has on init, and move current month back one last_month = get_month(-1) db.current_month = last_month.month - db.message_tables = {} + db.message_tables = [make_rotating_tablename("message", delta=-1), + make_rotating_tablename("message", delta=0)] # Create the next month's table, just in case today is the day before # a new month, in which case the lack of keys will cause an error in # update_rotating_tables next_month = get_month(1) - db.message_tables[next_month.month] = None + assert next_month.month not in db.message_tables # Get the deferred back e = Deferred() d = db.update_rotating_tables() def check_tables(result): - assert len(db.message_tables) == 2 assert db.current_month == get_month().month + assert len(db.message_tables) == 2 d.addCallback(check_tables) d.addBoth(lambda x: e.callback(True)) @@ -115,7 +123,9 @@ def test_update_rotating_tables_month_end(self): conf = AutopushConfig( hostname="example.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource=autopush.tests.boto_resource) db._tomorrow = Mock(return_value=tomorrow) db.create_initial_message_tables() @@ -123,37 +133,33 @@ def test_update_rotating_tables_month_end(self): assert len(db.message_tables) == 3 # Grab next month's table name and remove it - next_month = get_rotating_message_table( + next_month = get_rotating_message_tablename( conf.message_table.tablename, - delta=1 + delta=1, + boto_resource=db.resource ) - db.message_tables.pop(next_month.table_name) + db.message_tables.pop(db.message_tables.index(next_month)) # Get the deferred back d = db.update_rotating_tables() def check_tables(result): assert len(db.message_tables) == 3 - assert next_month.table_name in db.message_tables + assert next_month in db.message_tables d.addCallback(check_tables) return d def test_update_not_needed(self): - from autopush.db import get_month conf = AutopushConfig( hostname="google.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource=autopush.tests.boto_resource) db.create_initial_message_tables() # Erase the tables it has on init, and move current month back one - db.message_tables = {} - - # Create the next month's table, just in case today is the day before - # a new month, in which case the lack of keys will cause an error in - # update_rotating_tables - next_month = get_month(1) - db.message_tables[next_month.month] = None + db.message_tables = [] # Get the deferred back e = Deferred() @@ -184,7 +190,7 @@ def tearDown(self): mock.stop() def test_basic(self): - connection_main([], False) + connection_main([], False, resource=autopush.tests.boto_resource) def test_ssl(self): connection_main([ @@ -193,12 +199,12 @@ def test_ssl(self): "--ssl_key=keys/server.key", "--router_ssl_cert=keys/server.crt", "--router_ssl_key=keys/server.key", - ], False) + ], False, resource=autopush.tests.boto_resource) def test_memusage(self): connection_main([ "--memusage_port=8083", - ], False) + ], False, resource=autopush.tests.boto_resource) def test_skip_logging(self): # Should skip setting up logging on the handler @@ -258,6 +264,7 @@ class TestArg(AutopushConfig): use_cryptography = False sts_max_age = 1234 _no_sslcontext_cache = False + aws_ddb_endpoint = None def setUp(self): patchers = [ @@ -277,15 +284,18 @@ def tearDown(self): autopush.db.key_hash = "" def test_basic(self): - endpoint_main([ - ], False) + endpoint_main( + [], + False, + resource=autopush.tests.boto_resource + ) def test_ssl(self): endpoint_main([ "--ssl_dh_param=keys/dhparam.pem", "--ssl_cert=keys/server.crt", "--ssl_key=keys/server.key", - ], False) + ], False, resource=autopush.tests.boto_resource) def test_bad_senderidlist(self): returncode = endpoint_main([ @@ -306,18 +316,18 @@ def test_client_certs(self): "--ssl_cert=keys/server.crt", "--ssl_key=keys/server.key", '--client_certs={"foo": ["%s"]}' % cert - ], False) + ], False, resource=autopush.tests.boto_resource) assert not returncode def test_proxy_protocol_port(self): endpoint_main([ "--proxy_protocol_port=8081", - ], False) + ], False, resource=autopush.tests.boto_resource) def test_memusage(self): endpoint_main([ "--memusage_port=8083", - ], False) + ], False, resource=autopush.tests.boto_resource) def test_client_certs_parse(self): conf = AutopushConfig.from_argparse(self.TestArg) @@ -344,7 +354,8 @@ def test_bad_client_certs(self): @patch('hyper.tls', spec=hyper.tls) def test_conf(self, *args): conf = AutopushConfig.from_argparse(self.TestArg) - app = EndpointApplication(conf) + app = EndpointApplication(conf, + resource=autopush.tests.boto_resource) # verify that the hostname is what we said. assert conf.hostname == self.TestArg.hostname assert app.routers["gcm"].router_conf['collapsekey'] == "collapse" @@ -375,7 +386,7 @@ def test_gcm_start(self): endpoint_main([ "--gcm_enabled", """--senderid_list={"123":{"auth":"abcd"}}""", - ], False) + ], False, resource=autopush.tests.boto_resource) @patch("requests.get") def test_aws_ami_id(self, request_mock): diff --git a/autopush/tests/test_metrics.py b/autopush/tests/test_metrics.py index f3f076a0..ef33896a 100644 --- a/autopush/tests/test_metrics.py +++ b/autopush/tests/test_metrics.py @@ -3,14 +3,13 @@ import twisted.internet.base import pytest -from mock import Mock, call, patch +from mock import Mock, patch from autopush.metrics import ( IMetrics, DatadogMetrics, TwistedMetrics, SinkMetrics, - periodic_reporter ) @@ -72,19 +71,3 @@ def test_basic(self, mock_dog): m.timing("lifespan", 113) m._client.timing.assert_called_with("testpush.lifespan", value=113, host=hostname) - - -class PeriodicReporterTestCase(unittest.TestCase): - - def test_periodic_reporter(self): - metrics = Mock(spec=SinkMetrics) - periodic_reporter(metrics) - periodic_reporter(metrics, prefix='foo') - metrics.gauge.assert_has_calls([ - call('twisted.threadpool.idleWorkerCount', 0), - call('twisted.threadpool.busyWorkerCount', 0), - call('twisted.threadpool.backloggedWorkCount', 0), - call('foo.twisted.threadpool.idleWorkerCount', 0), - call('foo.twisted.threadpool.busyWorkerCount', 0), - call('foo.twisted.threadpool.backloggedWorkCount', 0), - ]) diff --git a/autopush/tests/test_router.py b/autopush/tests/test_router.py index 9e721585..cf25aa23 100644 --- a/autopush/tests/test_router.py +++ b/autopush/tests/test_router.py @@ -7,9 +7,8 @@ import requests import ssl -from autopush.utils import WebPushNotification -from mock import Mock, PropertyMock, patch import pytest +from mock import Mock, PropertyMock, patch from twisted.trial import unittest from twisted.internet.error import ConnectionRefusedError from twisted.internet.defer import inlineCallbacks @@ -35,6 +34,7 @@ from autopush.router.interface import RouterResponse, IRouter from autopush.tests import MockAssist from autopush.tests.support import test_db +from autopush.utils import WebPushNotification class RouterInterfaceTestCase(TestCase): @@ -999,7 +999,7 @@ def setUp(self): mock_result.not_registered = dict() mock_result.retry_after = 1000 self.router_mock = db.router - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) self.conf = conf def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): @@ -1009,6 +1009,7 @@ def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): side_effect=MockAssist([202, 200])) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1033,6 +1034,7 @@ def test_route_failure(self): self.agent_mock.request = Mock(side_effect=ConnectionRefusedError) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1070,6 +1072,7 @@ def test_route_to_busy_node_with_ttl_zero(self): side_effect=MockAssist([202, 200])) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1100,9 +1103,8 @@ def throw(): self.agent_mock.request.return_value = response_mock = Mock() response_mock.code = 202 - self.message_mock.store_message.side_effect = MockAssist( - [throw] - ) + self.message_mock.store_message.side_effect = MockAssist([throw]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) @@ -1123,6 +1125,7 @@ def throw(): raise JSONResponseError(500, "Whoops") self.message_mock.store_message.return_value = True + self.db.message_table = Mock(return_value=self.message_mock) self.router_mock.get_uaid.side_effect = MockAssist( [throw] ) @@ -1144,6 +1147,7 @@ def throw(): raise ItemNotFound() self.message_mock.store_message.return_value = True + self.db.message_table = Mock(return_value=self.message_mock) self.router_mock.get_uaid.side_effect = MockAssist( [throw] ) @@ -1160,9 +1164,8 @@ def verify_deliver(status): def test_route_lookup_uaid_no_nodeid(self): self.message_mock.store_message.return_value = True - self.router_mock.get_uaid.return_value = dict( - - ) + self.db.message_table = Mock(return_value=self.message_mock) + self.router_mock.get_uaid.return_value = dict() router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) @@ -1179,6 +1182,7 @@ def test_route_and_clear_failure(self): self.agent_mock.request = Mock(side_effect=ConnectionRefusedError) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data diff --git a/autopush/tests/test_web_base.py b/autopush/tests/test_web_base.py index dc56255d..fd1d2f71 100644 --- a/autopush/tests/test_web_base.py +++ b/autopush/tests/test_web_base.py @@ -195,8 +195,7 @@ def test_response_err(self): def test_overload_err(self): try: - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': { 'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' @@ -208,8 +207,7 @@ def test_overload_err(self): def test_client_err(self): try: - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': { 'Code': 'Flibbertygidgit'}}, 'mock_update_item' diff --git a/autopush/tests/test_web_validation.py b/autopush/tests/test_web_validation.py index 96a978b0..6122eed4 100644 --- a/autopush/tests/test_web_validation.py +++ b/autopush/tests/test_web_validation.py @@ -593,7 +593,7 @@ def test_no_current_month(self): assert cm.value.status_code == 410 assert cm.value.errno == 106 - assert cm.value.message == "Subscription elapsed" + assert cm.value.message == "No such subscription" def test_old_current_month(self): schema = self._make_fut() @@ -616,7 +616,7 @@ def test_old_current_month(self): assert cm.value.status_code == 410 assert cm.value.errno == 106 - assert cm.value.message == "Subscription expired" + assert cm.value.message == "No such subscription" class TestWebPushRequestSchemaUsingVapid(unittest.TestCase): diff --git a/autopush/tests/test_web_webpush.py b/autopush/tests/test_web_webpush.py index ca6327f5..204fe0cd 100644 --- a/autopush/tests/test_web_webpush.py +++ b/autopush/tests/test_web_webpush.py @@ -20,6 +20,7 @@ class TestWebpushHandler(unittest.TestCase): def setUp(self): + import autopush from autopush.web.webpush import WebPushHandler self.conf = conf = AutopushConfig( @@ -30,11 +31,13 @@ def setUp(self): self.fernet_mock = conf.fernet = Mock(spec=Fernet) self.db = db = test_db() - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) + self.db.message_table = Mock(return_value=self.message_mock) self.message_mock.all_channels.return_value = (True, [dummy_chid]) app = EndpointHTTPFactory.for_handler(WebPushHandler, conf, db=db) self.wp_router_mock = app.routers["webpush"] = Mock(spec=IRouter) + self.db.router = Mock(spec=autopush.db.Router) self.client = Client(app) def url(self, **kwargs): @@ -48,16 +51,18 @@ def test_router_needs_update(self): public_key="asdfasdf", )) self.fernet_mock.decrypt.return_value = dummy_token - self.db.router.get_uaid.return_value = dict( + self.db.router.get_uaid = Mock(return_value=dict( router_type="webpush", router_data=dict(), uaid=dummy_uaid, current_month=self.db.current_msg_month, - ) + )) + self.db.router.register_user = Mock(return_value=False) self.wp_router_mock.route_notification.return_value = RouterResponse( status_code=503, router_data=dict(token="new_connect"), ) + self.db.message_tables.append(self.db.current_msg_month) resp = yield self.client.post( self.url(api_ver="v1", token=dummy_token), @@ -85,6 +90,7 @@ def test_router_returns_data_without_detail(self): status_code=503, router_data=dict(), ) + self.db.message_tables.append(self.db.current_msg_month) resp = yield self.client.post( self.url(api_ver="v1", token=dummy_token), diff --git a/autopush/tests/test_webpush_server.py b/autopush/tests/test_webpush_server.py index 2805ff37..367b7dfe 100644 --- a/autopush/tests/test_webpush_server.py +++ b/autopush/tests/test_webpush_server.py @@ -15,6 +15,7 @@ DatabaseManager, generate_last_connect, make_rotating_tablename, + Message, ) from autopush.metrics import SinkMetrics from autopush.config import AutopushConfig @@ -35,6 +36,7 @@ Unregister, WebPushMessage, ) +import autopush.tests class AutopushCall(object): @@ -170,14 +172,18 @@ def setUp(self): begin_or_register(self.logs) self.addCleanup(globalLogPublisher.removeObserver, self.logs) - self.db = db = DatabaseManager.from_config(self.conf) + self.db = db = DatabaseManager.from_config( + self.conf, + resource=autopush.tests.boto_resource) self.metrics = db.metrics = Mock(spec=SinkMetrics) db.setup_tables() def _store_messages(self, uaid, topic=False, num=5): try: item = self.db.router.get_uaid(uaid.hex) - message_table = self.db.message_tables[item["current_month"]] + message_table = Message( + item["current_month"], + boto_resource=autopush.tests.boto_resource) except ItemNotFound: message_table = self.db.message messages = [WebPushNotificationFactory(uaid=uaid) @@ -441,7 +447,9 @@ def test_migrate_user(self): # Check that it's there item = self.db.router.get_uaid(uaid) - _, channels = self.db.message_tables[last_month].all_channels(uaid) + _, channels = Message( + last_month, + boto_resource=self.db.resource).all_channels(uaid) assert item["current_month"] != self.db.current_msg_month assert item is not None assert len(channels) == 3 @@ -504,16 +512,18 @@ def test_register_bad_chid_nodash(self): self._test_invalid(uuid4().hex) def test_register_over_provisioning(self): + import autopush def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + from botocore.exceptions import ClientError + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - self.db.message.table.update_item = Mock( - side_effect=raise_condition) + mock_table = Mock(spec=autopush.db.Message) + mock_table.register_channel = Mock(side_effect=raise_condition) + self.db.message_table = Mock(return_value=mock_table) self._test_invalid(str(uuid4()), "overloaded", 503) diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 9f973dad..93c08ff5 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -8,6 +8,7 @@ import twisted.internet.base from autobahn.twisted.util import sleep from autobahn.websocket.protocol import ConnectionRequest +from botocore.exceptions import ClientError from mock import Mock, patch import pytest from twisted.internet import reactor @@ -25,7 +26,6 @@ from autopush.db import DatabaseManager from autopush.http import InternalRouterHTTPFactory from autopush.metrics import SinkMetrics -from autopush.tests import MockAssist from autopush.utils import WebPushNotification from autopush.tests.client import Client from autopush.tests.test_db import make_webpush_notification @@ -37,6 +37,7 @@ WebSocketServerProtocol, ) from autopush.utils import base64url_encode, ms_time +import autopush.tests dummy_version = (u'gAAAAABX_pXhN22H-hvscOHsMulKvtC0hKJimrZivbgQPFB3sQAtOPmb' @@ -101,7 +102,10 @@ def setUp(self): statsd_host=None, env="test", ) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource=autopush.tests.boto_resource + ) self.metrics = db.metrics = Mock(spec=SinkMetrics) db.setup_tables() @@ -381,7 +385,9 @@ def test_close_with_delivery_cleanup(self): self.proto.ps.direct_updates[chid] = [notif] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -400,7 +406,9 @@ def test_close_with_delivery_cleanup_using_webpush(self): self.proto.ps.direct_updates[dummy_chid_str] = [dummy_notif()] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -421,7 +429,9 @@ def test_close_with_delivery_cleanup_and_get_no_result(self): self.proto.ps.direct_updates[chid] = [notif] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = False self.metrics.reset_mock() @@ -451,7 +461,7 @@ def test_hello_old(self): "current_month": msg_date, } router = self.proto.db.router - router.table.put_item( + router.table().put_item( Item=dict( uaid=orig_uaid, connected_at=ms_time(), @@ -460,7 +470,7 @@ def test_hello_old(self): ) ) - def fake_msg(data): + def fake_msg(data, **kwargs): return (True, msg_data) mock_msg = Mock(wraps=db.Message) @@ -473,12 +483,10 @@ def fake_msg(data): # notifications are irrelevant for this test. self.proto.process_notifications = Mock() # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - for k in mt.keys(): - del(mt[k]) - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', 'message_2016_3' + ] + self.proto.ps.db.message_table = Mock(return_value=mock_msg) with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: patched.today.return_value = target_day @@ -521,7 +529,7 @@ def test_hello_tomorrow(self): "current_month": msg_date, } - def fake_msg(data): + def fake_msg(data, **kwargs): return (True, msg_data) mock_msg = Mock(wraps=db.Message) @@ -530,12 +538,10 @@ def fake_msg(data): mock_msg.all_channels.return_value = (None, []) self.proto.db.router.register_user = fake_msg # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - for k in mt.keys(): - del(mt[k]) - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg + self.proto.db.message_table = Mock(return_value=mock_msg) + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', 'message_2016_3' + ] with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: patched.today.return_value = target_day @@ -559,10 +565,11 @@ def fake_msg(data): def test_hello_tomorrow_provision_error(self): orig_uaid = "deadbeef00000000abad1dea00000000" router = self.proto.db.router + current_month = "message_2016_3" router.register_user(dict( uaid=orig_uaid, connected_at=ms_time(), - current_month="message_2016_3", + current_month=current_month, router_type="webpush", )) @@ -580,36 +587,46 @@ def test_hello_tomorrow_provision_error(self): "current_month": msg_date, } - def fake_msg(data): - return (True, msg_data) - mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = "01;", [] mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) - self.proto.db.router.register_user = fake_msg + self.proto.db.router.register_user = Mock( + return_value=(True, msg_data) + ) # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - mt.clear() - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg - + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', current_month + ] + self.proto.db.message_table = Mock(return_value=mock_msg) patch_range = patch("autopush.websocket.randrange") mock_patch = patch_range.start() mock_patch.return_value = 1 + self.flag = True def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( - {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, - 'mock_update_item' - ) - - self.proto.db.router.update_message_month = MockAssist([ - raise_condition, - Mock(), - ]) + if self.flag: + self.flag = False + raise ClientError( + {'Error': + {'Code': 'ProvisionedThroughputExceededException'}}, + 'mock_update_item' + ) + + self.proto.db.register_user = Mock(return_value=(False, {})) + mock_router = Mock(spec=db.Router) + mock_router.register_user = Mock(return_value=(True, msg_data)) + mock_router.update_message_month = Mock(side_effect=raise_condition) + self.proto.db.router = mock_router + self.proto.db.router.get_uaid = Mock(return_value={ + "router_type": "webpush", + "connected_at": int(msg_day.strftime("%s")), + "current_month": 'message_2016_2', + "last_connect": int(msg_day.strftime("%s")), + "record_version": 1, + }) + self.proto.db.current_msg_month = current_month + self.proto.ps.message_month = current_month with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: @@ -629,7 +646,8 @@ def raise_condition(*args, **kwargs): # Wait to see that the message table gets rotated yield self._wait_for( - lambda: not self.proto.ps.rotate_message_table + lambda: not self.proto.ps.rotate_message_table, + duration=5000 ) assert self.proto.ps.rotate_message_table is False finally: @@ -711,7 +729,9 @@ def test_hello_failure(self): self._connect() # Fail out the register_user call router = self.proto.db.router - router.table.update_item = Mock(side_effect=KeyError) + mock_up = Mock() + mock_up.update_item = Mock(side_effect=KeyError) + router.table = Mock(return_value=mock_up) self._send_message(dict(messageType="hello", channelIDs=[], use_webpush=True, stop=1)) @@ -727,15 +747,15 @@ def test_hello_provisioned_during_check(self): # Fail out the register_user call def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) router = self.proto.db.router - router.table.update_item = Mock(side_effect=raise_condition) - + mock_table = Mock() + mock_table.update_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mock_table) self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) msg = yield self.get_response() @@ -911,10 +931,12 @@ def test_register_webpush(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.register_channel = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) yield self.proto.process_register(dict(channelID=chid)) - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -922,7 +944,8 @@ def test_register_webpush_with_key(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = Mock() + msg_mock = Mock(spec=db.Message) + self.proto.db.message_table = Mock(return_value=msg_mock) test_key = "SomeRandomCryptoKeyString" test_sha = sha256(test_key).hexdigest() test_endpoint = ('http://localhost/wpush/v2/' + @@ -942,7 +965,7 @@ def echo(string): ) assert test_endpoint == self.proto.sendJSON.call_args[0][0][ 'pushEndpoint'] - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -1046,11 +1069,12 @@ def test_register_over_provisioning(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = register = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.register_channel = register = Mock() + self.proto.ps.db.message_table = Mock(return_value=msg_mock) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) @@ -1058,7 +1082,7 @@ def raise_condition(*args, **kwargs): register.side_effect = raise_condition yield self.proto.process_register(dict(channelID=chid)) - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert self.send_mock.called args, _ = self.send_mock.call_args msg = json.loads(args[0]) @@ -1305,9 +1329,11 @@ def test_process_notifications(self): self.proto.ps.uaid = uuid.uuid4().hex # Swap out fetch_notifications - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( return_value=(None, []) ) + self.proto.ps.db.message_table = Mock(return_value=msg_mock) self.proto.process_notifications() @@ -1337,13 +1363,12 @@ def throw(*args, **kwargs): twisted.internet.base.DelayedCall.debug = True self._connect() - self.proto.db.message.fetch_messages = Mock( - return_value=(None, []) - ) - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( side_effect=throw) - self.proto.db.message.fetch_timestamp_messages = Mock( + msg_mock.fetch_timestamp_messages = Mock( side_effect=throw) + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.ps.uaid = uuid.uuid4().hex self.proto.process_notifications() @@ -1367,21 +1392,19 @@ def wait(result): def test_process_notifications_provision_err(self): def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) twisted.internet.base.DelayedCall.debug = True self._connect() - self.proto.db.message.fetch_messages = Mock( - return_value=(None, []) - ) - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( side_effect=raise_condition) - self.proto.db.message.fetch_timestamp_messages = Mock( + msg_mock.fetch_timestamp_messages = Mock( side_effect=raise_condition) + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.deferToLater = Mock() self.proto.ps.uaid = uuid.uuid4().hex @@ -1488,7 +1511,9 @@ def test_notif_finished_with_too_many_messages(self): self.proto.ps.uaid = uuid.uuid4().hex self.proto.ps._check_notifications = True self.proto.db.router.drop_user = Mock() - self.proto.ps.message.fetch_messages = Mock() + msg_mock = Mock() + msg_mock.fetch_messages = Mock() + self.proto.ps.db.message_table = Mock(return_value=msg_mock) notif = make_webpush_notification( self.proto.ps.uaid, @@ -1496,7 +1521,7 @@ def test_notif_finished_with_too_many_messages(self): ttl=500 ) self.proto.ps.updates_sent = defaultdict(lambda: []) - self.proto.ps.message.fetch_messages.return_value = ( + msg_mock.fetch_messages.return_value = ( None, [notif, notif, notif] ) diff --git a/autopush/web/health.py b/autopush/web/health.py index dd120668..59665a61 100644 --- a/autopush/web/health.py +++ b/autopush/web/health.py @@ -33,14 +33,15 @@ def get(self): ) dl = DeferredList([ - self._check_table(self.db.router.table), - self._check_table(self.db.message.table, "storage") + self._check_table(self.db.router.table()), + self._check_table(self.db.message.table(), "storage"), ]) dl.addBoth(self._finish_response) def _check_table(self, table, name_over=None): """Checks the tables known about in DynamoDB""" - d = deferToThread(table_exists, table.table_name, self.db.client) + d = deferToThread(table_exists, table.table_name, + boto_resource=self.db.resource) d.addCallback(self._check_success, name_over or table.table_name) d.addErrback(self._check_error, name_over or table.table_name) return d @@ -58,12 +59,12 @@ def _check_error(self, failure, name): self.log.failure(format=fmt, failure=failure, name=name) cause = self._health_checks[name] = {"status": "NOT OK"} - if failure.check(InternalServerError): + if failure.check(MissingTableException): + cause["error"] = failure.value.message + elif failure.check(InternalServerError): # pragma nocover cause["error"] = "Server error" - elif failure.check(MissingTableException): - cause["error"] = failure.getErrorMessage() else: - cause["error"] = "Internal error" + cause["error"] = "Internal error" # pragma nocover def _finish_response(self, results): """Returns whether the check succeeded or not""" diff --git a/autopush/web/registration.py b/autopush/web/registration.py index b94283b4..235a6469 100644 --- a/autopush/web/registration.py +++ b/autopush/web/registration.py @@ -149,8 +149,7 @@ def validate_auth(self, data): auth_type, auth_token = re.sub( r' +', ' ', auth.strip()).split(" ", 2) except ValueError: - raise InvalidRequest("Invalid Authentication", - status_code=401, + raise InvalidRequest("Invalid Authentication", status_code=401, errno=109, headers=request_pref_header) if auth_type.lower() not in AUTH_SCHEMES: @@ -484,4 +483,4 @@ def _chid_not_found_err(self, fail): self.log.debug(format="CHID not found in AWS.", status_code=410, errno=106, **self._client_info) - self._write_response(410, 106, message="Invalid endpoint for user.") + self._write_response(410, 106, message="Invalid endpoint.") diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 50294fe5..954198b8 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -97,8 +97,7 @@ def validate_uaid_month_and_chid(self, d): self.context["metrics"].increment("updates.drop_user", tags=make_tags(errno=102)) self.context["db"].router.drop_user(result["uaid"]) - raise InvalidRequest("No route for subscription", - status_code=410, + raise InvalidRequest("No such subscription", status_code=410, errno=106) if (router_type in ["gcm", "fcm"] @@ -138,8 +137,7 @@ def _validate_webpush(self, d, result): uaid_record=repr(result)) metrics.increment("updates.drop_user", tags=make_tags(errno=102)) db.router.drop_user(uaid) - raise InvalidRequest("Subscription elapsed", - status_code=410, + raise InvalidRequest("No such subscription", status_code=410, errno=106) month_table = result["current_month"] @@ -149,17 +147,16 @@ def _validate_webpush(self, d, result): uaid_record=repr(result)) metrics.increment("updates.drop_user", tags=make_tags(errno=103)) db.router.drop_user(uaid) - raise InvalidRequest("Subscription expired", - status_code=410, + raise InvalidRequest("No such subscription", status_code=410, errno=106) - exists, chans = db.message_tables[month_table].all_channels(uaid=uaid) + msg = db.message_table(month_table) + exists, chans = msg.all_channels(uaid=uaid) if (not exists or channel_id.lower() not in map(lambda x: normalize_id(x), chans)): log.debug("Unknown subscription: {channel_id}", channel_id=channel_id) - raise InvalidRequest("No such subscription for user", - status_code=410, + raise InvalidRequest("No such subscription", status_code=410, errno=106) diff --git a/autopush/webpush_server.py b/autopush/webpush_server.py index 596cfe6e..a738abf4 100644 --- a/autopush/webpush_server.py +++ b/autopush/webpush_server.py @@ -25,6 +25,7 @@ has_connected_this_month, hasher, generate_last_connect, + Message, ) from autopush.config import AutopushConfig # noqa @@ -459,10 +460,11 @@ def process(self, command): def _check_storage(self, command): timestamp = None messages = [] - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_resource=self.db.resource) if command.include_topic: timestamp, messages = message.fetch_messages( - uaid=command.uaid, limit=11, + uaid=command.uaid, limit=11 ) # If we have topic messages, return them immediately @@ -488,7 +490,8 @@ def _check_storage(self, command): class IncrementStorageCommand(ProcessorCommand): def process(self, command): # type: (IncStoragePosition) -> IncStoragePositionResponse - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_resource=self.db.resource) message.update_last_message_read(command.uaid, command.timestamp) return IncStoragePositionResponse() @@ -497,7 +500,8 @@ class DeleteMessageCommand(ProcessorCommand): def process(self, command): # type: (DeleteMessage) -> DeleteMessageResponse notif = command.message.to_WebPushNotification() - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_resource=self.db.resource) message.delete_message(notif) return DeleteMessageResponse() @@ -513,25 +517,30 @@ class MigrateUserCommand(ProcessorCommand): def process(self, command): # type: (MigrateUser) -> MigrateUserResponse # Get the current channels for this month - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_resource=self.db.resource) _, channels = message.all_channels(command.uaid.hex) # Get the current message month cur_month = self.db.current_msg_month if channels: # Save the current channels into this months message table - msg_table = self.db.message_tables[cur_month] - msg_table.save_channels(command.uaid.hex, channels) + msg_table = Message(cur_month, + boto_resource=self.db.resource) + msg_table.save_channels(command.uaid.hex, + channels) # Finally, update the route message month - self.db.router.update_message_month(command.uaid.hex, cur_month) + self.db.router.update_message_month(command.uaid.hex, + cur_month) return MigrateUserResponse(message_month=cur_month) class StoreMessagesUserCommand(ProcessorCommand): def process(self, command): # type: (StoreMessages) -> StoreMessagesResponse - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_resource=self.db.resource) for m in command.messages: if "topic" not in m: m["topic"] = None @@ -585,15 +594,15 @@ def process(self, command): command.channel_id, command.key ) - message = self.db.message_tables[command.message_month] + message = self.db.message_table(command.message_month) try: - message.register_channel(command.uaid.hex, command.channel_id) + message.register_channel(command.uaid.hex, + command.channel_id) except ClientError as ex: if (ex.response['Error']['Code'] == "ProvisionedThroughputExceededException"): return RegisterErrorResponse(error_msg="overloaded", status=503) - self.metrics.increment('ua.command.register') log.info( "Register", @@ -634,12 +643,12 @@ def process(self, if not valid: return UnregisterErrorResponse(error_msg=msg) - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_resource=self.db.resource) # TODO: JSONResponseError not handled (no force_retry) message.unregister_channel(command.uaid.hex, command.channel_id) - # TODO: Clear out any existing tracked messages for this - # channel + # TODO: Clear out any existing tracked messages for this channel self.metrics.increment('ua.command.unregister') # TODO: user/raw_agent? diff --git a/autopush/websocket.py b/autopush/websocket.py index a3de965a..8b17e756 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -245,12 +245,6 @@ def __attrs_post_init__(self): self.reset_uaid = False - @property - def message(self): - # type: () -> Message - """Property to access the currently used message table""" - return self.db.message_tables[self.message_month] - @property def user_agent(self): # type: () -> str @@ -596,7 +590,8 @@ def cleanUp(self, wasClean, code, reason): def _save_webpush_notif(self, notif): """Save a direct_update webpush style notification""" - return deferToThread(self.ps.message.store_message, + message = self.db.message_table(self.ps.message_month) + return deferToThread(message.store_message, notif).addErrback(self.log_failure) def _lookup_node(self, results): @@ -648,7 +643,7 @@ def returnError(self, messageType, reason, statusCode, close=True, if close: self.sendClose() - def err_overload(self, failure, message_type, disconnect=True): + def error_overload(self, failure, message_type, disconnect=True): """Handle database overloads and errors If ``disconnect`` is False, the an overload error is returned and the @@ -668,7 +663,7 @@ def err_overload(self, failure, message_type, disconnect=True): if disconnect: self.transport.pauseProducing() d = self.deferToLater(self.randrange(4, 9), - self.err_finish_overload, message_type) + self.error_finish_overload, message_type) d.addErrback(self.trap_cancel) else: if (failure.value.response["Error"]["Code"] != @@ -678,7 +673,7 @@ def err_overload(self, failure, message_type, disconnect=True): "status": 503} self.sendJSON(send) - def err_finish_overload(self, message_type): + def error_finish_overload(self, message_type): """Close the connection down and resume consuming input after the random interval from a db overload""" # Resume producing so we can finish the shutdown @@ -716,8 +711,8 @@ def process_hello(self, data): d = self.deferToThread(self._register_user, existing_user) d.addCallback(self._check_other_nodes) d.addErrback(self.trap_cancel) - d.addErrback(self.err_overload, "hello") - d.addErrback(self.err_hello) + d.addErrback(self.error_overload, "hello") + d.addErrback(self.error_hello) self.ps._register = d return d @@ -779,8 +774,7 @@ def _verify_user_record(self): self.log.debug(format="Dropping User", code=105, uaid_hash=self.ps.uaid_hash, uaid_record=repr(record)) - self.force_retry(self.db.router.drop_user, - self.ps.uaid) + self.force_retry(self.db.router.drop_user, self.ps.uaid) tags = ['code:105'] self.metrics.increment("ua.expiration", tags=tags) return None @@ -806,7 +800,7 @@ def _verify_user_record(self): record["connected_at"] = self.ps.connected_at return record - def err_hello(self, failure): + def error_hello(self, failure): """errBack for hello failures""" self.transport.resumeProducing() self.log_failure(failure) @@ -903,11 +897,14 @@ def process_notifications(self): def webpush_fetch(self): """Helper to return an appropriate function to fetch messages""" + message = self.db.message_table(self.ps.message_month) if self.ps.scan_timestamps: - return partial(self.ps.message.fetch_timestamp_messages, - self.ps.uaid_obj, self.ps.current_timestamp) + return partial(message.fetch_timestamp_messages, + self.ps.uaid_obj, + self.ps.current_timestamp) else: - return partial(self.ps.message.fetch_messages, self.ps.uaid_obj) + return partial(message.fetch_messages, + self.ps.uaid_obj) def error_notifications(self, fail): """errBack for notification check failing""" @@ -990,13 +987,15 @@ def finish_webpush_notifications(self, result): # Send out all the notifications now = int(time.time()) messages_sent = False + message = self.db.message_table(self.ps.message_month) for notif in notifs: self.ps.stats.stored_retrieved += 1 # If the TTL is too old, don't deliver and fire a delete off if notif.expired(at_time=now): if not notif.sortkey_timestamp: # Delete non-timestamped messages - self.force_retry(self.ps.message.delete_message, notif) + self.force_retry(message.delete_message, + notif) # nocover here as coverage gets confused on the line below # for unknown reasons @@ -1010,7 +1009,8 @@ def finish_webpush_notifications(self, result): raise MessageOverloadException() if notif.topic: self.metrics.increment("ua.notification.topic") - self.metrics.increment('ua.message_data', len(msg.get('data', '')), + self.metrics.increment('ua.message_data', + len(msg.get('data', '')), tags=make_tags(source=notif.source)) self.sendJSON(msg) @@ -1021,8 +1021,9 @@ def finish_webpush_notifications(self, result): # No messages sent, update the record if needed if self.ps.current_timestamp: self.force_retry( - self.ps.message.update_last_message_read, - self.ps.uaid_obj, self.ps.current_timestamp) + message.update_last_message_read, + self.ps.uaid_obj, + self.ps.current_timestamp) # Schedule a new process check self.check_missed_notifications(None) @@ -1047,13 +1048,14 @@ def _monthly_transition(self): """ # Get the current channels for this month - _, channels = self.ps.message.all_channels(self.ps.uaid) + message = self.db.message_table(self.ps.message_month) + _, channels = message.all_channels(self.ps.uaid) # Get the current message month cur_month = self.db.current_msg_month if channels: # Save the current channels into this months message table - msg_table = self.db.message_tables[cur_month] + msg_table = self.db.message_table(cur_month) msg_table.save_channels(self.ps.uaid, channels) # Finally, update the route message month @@ -1069,7 +1071,7 @@ def _finish_monthly_transition(self, result): def error_monthly_rotation_overload(self, fail): """Capture overload on monthly table rotation attempt - If a provision exdeeded error hits while attempting monthly table + If a provision exceeded error hits while attempting monthly table rotation, schedule it all over and re-scan the messages. Normal websocket client flow is returned in the meantime. @@ -1137,12 +1139,13 @@ def error_register(self, fail): def finish_register(self, endpoint, chid): """callback for successful endpoint creation, sends register reply""" - d = self.deferToThread(self.ps.message.register_channel, self.ps.uaid, + message = self.db.message_table(self.ps.message_month) + d = self.deferToThread(message.register_channel, self.ps.uaid, chid) d.addCallback(self.send_register_finish, endpoint, chid) # Note: No trap_cancel needed here since the deferred here is # returned to process_register which will trap it - d.addErrback(self.err_overload, "register", disconnect=False) + d.addErrback(self.error_overload, "register", disconnect=False) return d def send_register_finish(self, result, endpoint, chid): @@ -1180,7 +1183,8 @@ def process_unregister(self, data): self.ps.updates_sent[chid] = [] # Unregister the channel - self.force_retry(self.ps.message.unregister_channel, self.ps.uaid, + message = self.db.message_table(self.ps.message_month) + self.force_retry(message.unregister_channel, self.ps.uaid, chid) data["status"] = 200 @@ -1238,6 +1242,7 @@ def ver_filter(notif): **self.ps.raw_agent) self.ps.stats.stored_acked += 1 + message = self.db.message_table(self.ps.message_month) if msg.sortkey_timestamp: # Is this the last un-acked message we're waiting for? last_unacked = sum( @@ -1249,7 +1254,7 @@ def ver_filter(notif): # If it's the last message in the batch, or last un-acked # message d = self.force_retry( - self.ps.message.update_last_message_read, + message.update_last_message_read, self.ps.uaid_obj, self.ps.current_timestamp, ) d.addBoth(self._handle_webpush_update_remove, chid, msg) @@ -1260,7 +1265,7 @@ def ver_filter(notif): d = None else: # No sortkey_timestamp, so legacy/topic message, delete - d = self.force_retry(self.ps.message.delete_message, msg) + d = self.force_retry(message.delete_message, msg) # We don't remove the update until we know the delete ran # This is because we don't use range queries on dynamodb and # we need to make sure this notification is deleted from the diff --git a/autopush_rs/Cargo.lock b/autopush_rs/Cargo.lock index 400b6032..3830a858 100644 --- a/autopush_rs/Cargo.lock +++ b/autopush_rs/Cargo.lock @@ -1,4 +1,21 @@ -[root] +[[package]] +name = "advapi32-sys" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "winapi 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi-build 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] +name = "aho-corasick" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "memchr 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", +] + +[[package]] name = "autopush" version = "0.1.0" dependencies = [ @@ -34,23 +51,6 @@ dependencies = [ "woothee 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)", ] -[[package]] -name = "advapi32-sys" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "winapi 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)", - "winapi-build 0.1.1 (registry+https://github.com/rust-lang/crates.io-index)", -] - -[[package]] -name = "aho-corasick" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -dependencies = [ - "memchr 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)", -] - [[package]] name = "backtrace" version = "0.3.4" diff --git a/docs/install.rst b/docs/install.rst index 2fb98226..1c60f253 100644 --- a/docs/install.rst +++ b/docs/install.rst @@ -40,7 +40,7 @@ Or a Debian based system (like Ubuntu): $ sudo apt-get install build-essential libffi-dev \ libssl-dev pypy-dev python-virtualenv git --assume-yes -Autopush uses the `Boto python library`_. Be sure to `properly set up your boto +Autopush uses the `Boto3 python library`_. Be sure to `properly set up your boto config file`_. Notes on OS X @@ -138,11 +138,10 @@ An example `boto config file`_ is provided in ``automock/boto.cfg`` that directs autopush to your local DynamoDB instance. .. _Mozilla Push Service - Code Development: http://mozilla-push-service.readthedocs.io/en/latest/development/#code-development -.. _`boto config file`: https://boto.readthedocs.io/en/latest/boto_config_tut.html +.. _`boto config file`: http://boto3.readthedocs.io/en/docs/guide/quickstart.html#configuration .. _`Local DynamoDB Java server`: http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Tools.DynamoDBLocal.html -.. _`Boto python library`: https://boto.readthedocs.io/en/latest/ -.. _`properly set up your boto config file`: - https://boto.readthedocs.io/en/latest/boto_config_tut.html +.. _`Boto3 python library`: https://boto3.readthedocs.io/en/latest/ +.. _`properly set up your boto config file`: http://boto3.readthedocs.io/en/docs/guide/quickstart.html#configuration .. _`cryptography`: https://cryptography.io/en/latest/installation .. toctree::