From b3376e2c03948137bdb71b41e9ca0df2e4ed0dea Mon Sep 17 00:00:00 2001 From: jr conlin Date: Thu, 30 Nov 2017 18:12:46 -0800 Subject: [PATCH] bug: make boto3 calls thread safe Closes #1081 --- autopush/config.py | 3 + autopush/constants.py | 4 + autopush/db.py | 421 ++++++++++++++++---------- autopush/diagnostic_cli.py | 17 +- autopush/main.py | 67 ++-- autopush/main_argparse.py | 4 + autopush/router/webpush.py | 2 +- autopush/scripts/drop_user.py | 25 +- autopush/tests/__init__.py | 32 +- autopush/tests/support.py | 4 +- autopush/tests/test_db.py | 307 ++++++++++++------- autopush/tests/test_diagnostic_cli.py | 32 +- autopush/tests/test_endpoint.py | 4 +- autopush/tests/test_health.py | 77 +---- autopush/tests/test_integration.py | 92 +++--- autopush/tests/test_main.py | 66 ++-- autopush/tests/test_router.py | 22 +- autopush/tests/test_web_base.py | 6 +- autopush/tests/test_web_webpush.py | 12 +- autopush/tests/test_webpush_server.py | 24 +- autopush/tests/test_websocket.py | 150 +++++---- autopush/web/health.py | 9 +- autopush/web/webpush.py | 3 +- autopush/webpush_server.py | 40 ++- autopush/websocket.py | 73 +++-- autopush_rs/Cargo.lock | 2 +- 26 files changed, 875 insertions(+), 623 deletions(-) create mode 100644 autopush/constants.py diff --git a/autopush/config.py b/autopush/config.py index 992523df..35a75038 100644 --- a/autopush/config.py +++ b/autopush/config.py @@ -176,6 +176,9 @@ class AutopushConfig(object): # Don't cache ssl.wrap_socket's SSLContexts no_sslcontext_cache = attrib(default=False) # type: bool + # DynamoDB endpoint override + aws_ddb_endpoint = attrib(default=None) # type: str + def __attrs_post_init__(self): """Initialize the Settings object""" # Setup hosts/ports/urls diff --git a/autopush/constants.py b/autopush/constants.py new file mode 100644 index 00000000..e8bc31fd --- /dev/null +++ b/autopush/constants.py @@ -0,0 +1,4 @@ +"""Shared constants which might produce circular includes""" + +# Number of concurrent threads / AWS Session resources +THREAD_POOL_SIZE = 50 diff --git a/autopush/db.py b/autopush/db.py index 2940c564..1def4089 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -36,20 +36,23 @@ import random import time import uuid +from collections import deque from functools import wraps from attr import ( asdict, attrs, attrib, - Factory -) + Factory, + Attribute) from boto.dynamodb2.exceptions import ( ItemNotFound, ) import boto3 import botocore +from boto.dynamodb2.table import Table # noqa +from boto3.resources.base import ServiceResource # noqa from boto3.dynamodb.conditions import Key from boto3.exceptions import Boto3Error from botocore.exceptions import ClientError @@ -73,6 +76,7 @@ from twisted.internet.threads import deferToThread import autopush.metrics +from autopush import constants from autopush.exceptions import AutopushException from autopush.metrics import IMetrics # noqa from autopush.types import ItemLike # noqa @@ -96,16 +100,53 @@ TRACK_DB_CALLS = False DB_CALLS = [] -# See https://botocore.readthedocs.io/en/stable/reference/config.html for -# additional config options -g_dynamodb = boto3.resource( - 'dynamodb', - config=botocore.config.Config( - region_name=os.getenv("AWS_REGION_NAME", "us-east-1") - ), - endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB") -) -g_client = g_dynamodb.meta.client +DDB_SESSION_ARGS = {} +MAX_DDB_SESSIONS = constants.THREAD_POOL_SIZE + + +class BotoResources(object): + """A pool of boto3 Resources. + + Boto3 resources are NOT thread safe. + + """ + def __init__(self, conf=None): + if conf is None: + conf = DDB_SESSION_ARGS # pragma: nocover + self.pool = deque(maxlen=MAX_DDB_SESSIONS) + if not conf.get("endpoint_url") and os.getenv("AWS_LOCAL_DYNAMODB"): + conf["endpoint_url"] = os.getenv("AWS_LOCAL_DYNAMODB") + if "endpoint_url" in conf and not conf["endpoint_url"]: + # If there's no endpoint URL value, we must delete the data + # entirely + del(conf["endpoint_url"]) + self.pool.extendleft((boto3.session.Session().resource( + 'dynamodb', + config=botocore.config.Config( + region_name=os.getenv("AWS_REGION_NAME", "us-east-1") + ), + **conf + ) for x in range(0, MAX_DDB_SESSIONS))) + + def fetch(self): + """Return a boto resource. This MUST be released after use.""" + # type: () -> ServiceResource + try: + return self.pool.pop() + except IndexError: + raise ClientError( + { + 'Error': { + 'Code': 'ProvisionedThroughputExceededException', + 'Message': 'Session pool exhausted', + } + }, + 'session.fetch') + + def release(self, resource): + """Return a boto3 resource to the pool.""" + # type: (ServiceResource) -> None + self.pool.appendleft(resource) def get_month(delta=0): @@ -147,20 +188,20 @@ def make_rotating_tablename(prefix, delta=0, date=None): def create_rotating_message_table(prefix="message", delta=0, date=None, read_throughput=5, - write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table + write_throughput=5, + boto_resource=None): + # type: (str, int, Optional[datetime.date], int, int, ServiceResource) -> Table # noqa """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta, date) try: - table = g_dynamodb.Table(tablename) + table = boto_resource.Table(tablename) if table.table_status == 'ACTIVE': # pragma nocover return table except ClientError as ex: if ex.response['Error']['Code'] != 'ResourceNotFoundException': - # If we hit this, our boto3 is misconfigured and we need to bail. - raise ex # pragma nocover - table = g_dynamodb.create_table( + pass + table = boto_resource.create_table( TableName=tablename, KeySchema=[ { @@ -201,24 +242,28 @@ def create_rotating_message_table(prefix="message", delta=0, date=None, return table -def get_rotating_message_table(prefix="message", delta=0, date=None, - message_read_throughput=5, - message_write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table +def get_rotating_message_tablename(prefix="message", delta=0, date=None, + message_read_throughput=5, + message_write_throughput=5, + boto_resource=None): + # type: (str, int, Optional[datetime.date], int, int, ServiceResource) -> str # noqa """Gets the message table for the current month.""" tablename = make_rotating_tablename(prefix, delta, date) - if not table_exists(tablename): - return create_rotating_message_table( + if not table_exists(tablename, boto_resource=boto_resource): + create_rotating_message_table( prefix=prefix, delta=delta, date=date, read_throughput=message_read_throughput, write_throughput=message_write_throughput, + boto_resource=boto_resource ) + return tablename else: - return g_dynamodb.Table(tablename) + return tablename def create_router_table(tablename="router", read_throughput=5, - write_throughput=5): + write_throughput=5, + boto_resource=None): # type: (str, int, int) -> Table """Create a new router table @@ -233,8 +278,7 @@ def create_router_table(tablename="router", read_throughput=5, cost of additional queries during GC to locate expired users. """ - - table = g_dynamodb.create_table( + table = boto_resource.create_table( TableName=tablename, KeySchema=[ { @@ -293,20 +337,22 @@ def create_router_table(tablename="router", read_throughput=5, return table -def _drop_table(tablename): +def _drop_table(tablename, boto_resource): try: - g_client.delete_table(TableName=tablename) + boto_resource.meta.client.delete_table(TableName=tablename) except ClientError: # pragma nocover pass -def _make_table(table_func, tablename, read_throughput, write_throughput): - # type: (Callable[[str, int, int], Table], str, int, int) -> Table +def _make_table(table_func, tablename, read_throughput, write_throughput, + boto_resource): + # type: (Callable[[str, int, int, ServiceResource]], str, int, int, ServiceResource) -> Table # noqa """Private common function to make a table with a table func""" - if not table_exists(tablename): - return table_func(tablename, read_throughput, write_throughput) + if not table_exists(tablename, boto_resource): + return table_func(tablename, read_throughput, write_throughput, + boto_resource) else: - return g_dynamodb.Table(tablename) + return boto_resource.Table(tablename) def _expiry(ttl): @@ -314,7 +360,7 @@ def _expiry(ttl): def get_router_table(tablename="router", read_throughput=5, - write_throughput=5): + write_throughput=5, boto_resource=None): # type: (str, int, int) -> Table """Get the main router table object @@ -323,10 +369,11 @@ def get_router_table(tablename="router", read_throughput=5, """ return _make_table(create_router_table, tablename, read_throughput, - write_throughput) + write_throughput, boto_resource=boto_resource) -def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): +def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000", + boto_resource=None): # type: (Message, Router, str) -> None """Performs a pre-flight check of the router/message to ensure appropriate permissions for operation. @@ -337,7 +384,7 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): # Verify tables are ready for use if they just got created ready = False while not ready: - tbl_status = [x.table_status() for x in [message, router]] + tbl_status = [x.table_status()for x in [message, router]] ready = all([status == "ACTIVE" for status in tbl_status]) if not ready: time.sleep(1) @@ -355,7 +402,6 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): message_id=message_id, ttl=60, ) - # Store a notification, fetch it, delete it message.store_message(notif) assert message.delete_message(notif) @@ -431,45 +477,51 @@ def generate_last_connect_values(date): yield int(val) -def list_tables(client=g_client): - """Return a list of the names of all DynamoDB tables.""" - start_table = None - while True: - if start_table: # pragma nocover - result = client.list_tables(ExclusiveStartTableName=start_table) - else: - result = client.list_tables() - for table in result.get('TableNames', []): - yield table - start_table = result.get('LastEvaluatedTableName', None) - if not start_table: - break - - -def table_exists(tablename, client=None): +def table_exists(tablename, boto_resource=None, resource_pool=None): + # type: (str, ServiceResource, BotoResources) -> bool """Determine if the specified Table exists""" - if not client: - client = g_client - return tablename in list_tables(client) + release = False + if not boto_resource: + boto_resource = resource_pool.fetch() + release = True + try: + return boto_resource.Table(tablename).table_status in [ + 'CREATING', 'UPDATING', 'ACTIVE'] + except ClientError: + return False + finally: + if release: + resource_pool.release(boto_resource) class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics) -> None + def __init__(self, tablename, metrics=None, resource_pool=None, + max_ttl=MAX_EXPIRY): + # type: (str, IMetrics, BotoResources, int) -> None """Create a new Message object - :param table: :class:`Table` object. - :param metrics: Metrics object that implements the - :class:`autopush.metrics.IMetrics` interface. + :param tablename: name of the table. + :param metrics: unused + :param resource_pool: BotoResources pool for thread """ - self.table = table - self.metrics = metrics + self.tablename = tablename self._max_ttl = max_ttl + self._resources = resource_pool + + def table(self, tablename=None): + if not tablename: + tablename = self.tablename + boto_resource = self._resources.fetch() + try: + table = boto_resource.Table(tablename) + finally: + self._resources.release(boto_resource) + return table def table_status(self): - return self.table.table_status + return self.table().table_status @track_provisioned def register_channel(self, uaid, channel_id, ttl=None): @@ -482,7 +534,7 @@ def register_channel(self, uaid, channel_id, ttl=None): ":channel_id": set([normalize_id(channel_id)]), ":expiry": _expiry(ttl) } - self.table.update_item( + self.table().update_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -500,7 +552,7 @@ def unregister_channel(self, uaid, channel_id, **kwargs): chid = normalize_id(channel_id) expr_values = {":channel_id": set([chid])} - response = self.table.update_item( + response = self.table().update_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -520,13 +572,13 @@ def unregister_channel(self, uaid, channel_id, **kwargs): @track_provisioned def all_channels(self, uaid): - # type: (str) -> Tuple[bool, Set[str]] + # type: (str, ServiceResource) -> Tuple[bool, Set[str]] """Retrieve a list of all channels for a given uaid""" # Note: This only returns the chids associated with the UAID. # Functions that call store_message() would be required to # update that list as well using register_channel() - result = self.table.get_item( + result = self.table().get_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -543,7 +595,7 @@ def all_channels(self, uaid): def save_channels(self, uaid, channels): # type: (str, Set[str]) -> None """Save out a set of channels""" - self.table.put_item( + self.table().put_item( Item={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -569,7 +621,7 @@ def store_message(self, notification): ) if notification.data: item['data'] = notification.data - self.table.put_item(Item=item) + self.table().put_item(Item=item) @track_provisioned def delete_message(self, notification): @@ -577,7 +629,7 @@ def delete_message(self, notification): """Deletes a specific message""" if notification.update_id: try: - self.table.delete_item( + self.table().delete_item( Key={ 'uaid': hasher(notification.uaid.hex), 'chidmessageid': notification.sort_key @@ -591,7 +643,7 @@ def delete_message(self, notification): except ClientError: return False else: - self.table.delete_item( + self.table().delete_item( Key={ 'uaid': hasher(notification.uaid.hex), 'chidmessageid': notification.sort_key, @@ -612,16 +664,16 @@ def fetch_messages( """ # Eagerly fetches all results in the result set. - response = self.table.query( + response = self.table().query( KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex)) & Key('chidmessageid').lt('02')), ConsistentRead=True, Limit=limit, ) results = list(response['Items']) - # First extract the position if applicable, slightly higher than 01: - # to ensure we don't load any 01 remainders that didn't get deleted - # yet + # First extract the position if applicable, slightly higher than + # 01: to ensure we don't load any 01 remainders that didn't get + # deleted yet last_position = None if results: # Ensure we return an int, as boto2 can return Decimals @@ -657,7 +709,7 @@ def fetch_timestamp_messages( else: sortkey = "01;" - response = self.table.query( + response = self.table().query( KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex)) & Key('chidmessageid').gt(sortkey)), ConsistentRead=True, @@ -680,7 +732,7 @@ def update_last_message_read(self, uaid, timestamp): expr = "SET current_timestamp=:timestamp, expiry=:expiry" expr_values = {":timestamp": timestamp, ":expiry": _expiry(self._max_ttl)} - self.table.update_item( + self.table().update_item( Key={ "uaid": hasher(uaid.hex), "chidmessageid": " " @@ -693,24 +745,40 @@ def update_last_message_read(self, uaid, timestamp): class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics) -> None + def __init__(self, conf, metrics, resource_pool, max_ttl=MAX_EXPIRY): + # type: (dict, IMetrics, BotoResources, int) -> None """Create a new Router object :param table: :class:`Table` object. :param metrics: Metrics object that implements the :class:`autopush.metrics.IMetrics` interface. + :param resource_pool: Pool of :class:`autopush.db.BotoResources` """ - self.table = table + self.conf = conf self.metrics = metrics self._max_ttl = max_ttl + self._cached_table = None + self._resources = resource_pool or BotoResources() + self._resource = None + + def table(self): + boto_resource = self._resources.fetch() + try: + if self.conf: + table = get_router_table(boto_resource=boto_resource, + **asdict(self.conf)) + else: + table = get_router_table(boto_resource=boto_resource) + finally: + self._resources.release(boto_resource) + return table def table_status(self): - return self.table.table_status + return self.table().table_status def get_uaid(self, uaid): - # type: (str) -> Item + # type: (str, ServiceResource) -> Item """Get the database record for the UAID :raises: @@ -720,7 +788,7 @@ def get_uaid(self, uaid): """ try: - item = self.table.get_item( + item = self.table().get_item( Key={ 'uaid': hasher(uaid) }, @@ -761,7 +829,8 @@ def register_user(self, data): db_key = {"uaid": hasher(data["uaid"])} del data["uaid"] if "router_type" not in data or "connected_at" not in data: - # Not specifying these values will generate an exception in AWS. + # Not specifying these values will generate an exception in + # AWS. raise AutopushException("data is missing router_type " "or connected_at") # Generate our update expression @@ -775,7 +844,7 @@ def register_user(self, data): attribute_not_exists(node_id) or (connected_at < :connected_at) )""" - result = self.table.update_item( + result = self.table().update_item( Key=db_key, UpdateExpression=expr, ConditionExpression=cond, @@ -786,7 +855,7 @@ def register_user(self, data): r = {} for key, value in result["Attributes"].items(): try: - r[key] = self.table._dynamizer.decode(value) + r[key] = self.table()._dynamizer.decode(value) except (TypeError, AttributeError): # pragma: nocover # Included for safety as moto has occasionally made # this not work @@ -794,8 +863,8 @@ def register_user(self, data): result = r return (True, result) except ClientError as ex: - # ClientErrors are generated by a factory, and while they have a - # class, it's dynamically generated. + # ClientErrors are generated by a factory, and while they have + # a class, it's dynamically generated. if ex.response['Error']['Code'] == \ 'ConditionalCheckFailedException': return (False, {}) @@ -808,7 +877,7 @@ def drop_user(self, uaid): # The following hack ensures that only uaids that exist and are # deleted return true. try: - item = self.table.get_item( + item = self.table().get_item( Key={ 'uaid': hasher(uaid) }, @@ -818,13 +887,14 @@ def drop_user(self, uaid): return False except ClientError: pass - result = self.table.delete_item(Key={'uaid': hasher(uaid)}) + result = self.table().delete_item( + Key={'uaid': hasher(uaid)}) return result['ResponseMetadata']['HTTPStatusCode'] == 200 def delete_uaids(self, uaids): # type: (List[str]) -> None """Issue a batch delete call for the given uaids""" - with self.table.batch_writer() as batch: + with self.table().batch_writer() as batch: for uaid in uaids: batch.delete_item(Key={'uaid': uaid}) @@ -850,6 +920,7 @@ def drop_old_users(self, months_ago=2): quickly as possible. :param months_ago: how many months ago since the last connect + :param boto_resource: ServiceResource for thread :returns: Iterable of how many deletes were run @@ -858,7 +929,7 @@ def drop_old_users(self, months_ago=2): batched = [] for hash_key in generate_last_connect_values(prior_date): - response = self.table.query( + response = self.table().query( KeyConditionExpression=Key("last_connect").eq(hash_key), IndexName="AccessIndex", ) @@ -878,7 +949,7 @@ def drop_old_users(self, months_ago=2): @track_provisioned def _update_last_connect(self, uaid, last_connect): - self.table.update_item( + self.table().update_item( Key={"uaid": hasher(uaid)}, UpdateExpression="SET last_connect=:last_connect", ExpressionAttributeValues={":last_connect": last_connect} @@ -902,7 +973,7 @@ def update_message_month(self, uaid, month): ":last_connect": generate_last_connect(), ":expiry": _expiry(self._max_ttl), } - self.table.update_item( + self.table().update_item( Key=db_key, UpdateExpression=expr, ExpressionAttributeValues=expr_values, @@ -929,7 +1000,7 @@ def clear_node(self, item): try: cond = "(node_id = :node) and (connected_at = :conn)" - self.table.put_item( + self.table().put_item( Item=item, ConditionExpression=cond, ExpressionAttributeValues={ @@ -952,15 +1023,15 @@ class DatabaseManager(object): _router_conf = attrib() # type: DDBTableConfig _message_conf = attrib() # type: DDBTableConfig - - metrics = attrib() # type: IMetrics + metrics = attrib() # type: IMetrics + _resources = attrib() # type: BotoResources router = attrib(default=None) # type: Optional[Router] - message_tables = attrib(default=Factory(dict)) # type: Dict[str, Message] + message_tables = attrib(default=Factory(list)) # type: List[str] current_msg_month = attrib(init=False) # type: Optional[str] current_month = attrib(init=False) # type: Optional[int] + _message = attrib(default=None) # type: Optional[Message] # for testing: - client = attrib(default=g_client) # type: Optional[Any] def __attrs_post_init__(self): """Initialize sane defaults""" @@ -970,16 +1041,23 @@ def __attrs_post_init__(self): self._message_conf.tablename, date=today ) + if not self._resources: + self._resources = BotoResources() @classmethod - def from_config(cls, conf, **kwargs): - # type: (AutopushConfig, **Any) -> DatabaseManager + def from_config(cls, conf, resource_pool=None, **kwargs): + # type: (AutopushConfig, BotoResources, **Any) -> DatabaseManager """Create a DatabaseManager from the given config""" metrics = autopush.metrics.from_config(conf) + if not resource_pool: + resource_pool = BotoResources(conf={ + 'endpoint_url': conf.aws_ddb_endpoint + }) return cls( router_conf=conf.router_table, message_conf=conf.message_table, metrics=metrics, + resources=resource_pool, **kwargs ) @@ -988,14 +1066,21 @@ def setup(self, preflight_uaid): """Setup metrics, message tables and perform preflight_check""" self.metrics.start() self.setup_tables() - preflight_check(self.message, self.router, preflight_uaid) + ddb_resource = self._resources.fetch() + try: + preflight_check(self.message, self.router, preflight_uaid, + boto_resource=ddb_resource) + finally: + self._resources.release(ddb_resource) def setup_tables(self): """Lookup or create the database tables""" self.router = Router( - get_router_table(**asdict(self._router_conf)), - self.metrics + conf=self._router_conf, + metrics=self.metrics, + resource_pool=self._resources, ) + self.router.table() # Used to determine whether a connection is out of date with current # db objects. There are three noteworty cases: # 1 "Last Month" the table requires a rollover. @@ -1005,18 +1090,24 @@ def setup_tables(self): # table is present before the switchover is the main reason for this, # just in case some nodes do switch sooner. self.create_initial_message_tables() + self._message = Message(self.current_msg_month, + self.metrics, + resource_pool=self.resources) @property def message(self): # type: () -> Message """Property that access the current message table""" - return self.message_tables[self.current_msg_month] + if not self._message or isinstance(self._message, Attribute): + self._message = self.message_table(self.current_msg_month) + return self._message + + @property + def resources(self): + return self._resources - @message.setter - def message(self, value): - # type: (Message) -> None - """Setter to set the current message table""" - self.message_tables[self.current_msg_month] = value + def message_table(self, tablename): + return Message(tablename, self.metrics, resource_pool=self.resources) def _tomorrow(self): # type: () -> datetime.date @@ -1031,32 +1122,35 @@ def create_initial_message_tables(self): """ mconf = self._message_conf today = datetime.date.today() - last_month = get_rotating_message_table( - prefix=mconf.tablename, - delta=-1, - message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput - ) - this_month = get_rotating_message_table( - prefix=mconf.tablename, - message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput - ) - self.current_month = today.month - self.current_msg_month = this_month.table_name - self.message_tables = { - last_month.table_name: Message(last_month, self.metrics), - this_month.table_name: Message(this_month, self.metrics) - } - if self._tomorrow().month != today.month: - next_month = get_rotating_message_table( + boto_resource = self.resources.fetch() + try: + last_month = get_rotating_message_tablename( prefix=mconf.tablename, - delta=1, + delta=-1, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_resource=boto_resource, + ) + this_month = get_rotating_message_tablename( + prefix=mconf.tablename, + message_read_throughput=mconf.read_throughput, + message_write_throughput=mconf.write_throughput, + boto_resource=boto_resource, ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) + self.current_month = today.month + self.current_msg_month = this_month + self.message_tables = [last_month, this_month] + if self._tomorrow().month != today.month: + next_month = get_rotating_message_tablename( + prefix=mconf.tablename, + delta=1, + message_read_throughput=mconf.read_throughput, + message_write_throughput=mconf.write_throughput, + boto_resource=boto_resource, + ) + self.message_tables.append(next_month) + finally: + self.resources.release(boto_resource) @inlineCallbacks def update_rotating_tables(self): @@ -1071,35 +1165,36 @@ def update_rotating_tables(self): mconf = self._message_conf today = datetime.date.today() tomorrow = self._tomorrow() - if ((tomorrow.month != today.month) and - sorted(self.message_tables.keys())[-1] != tomorrow.month): - next_month = yield deferToThread( - get_rotating_message_table, + boto_resource = self.resources.fetch() + try: + if ((tomorrow.month != today.month) and + sorted(self.message_tables)[-1] != tomorrow.month): + next_month = yield deferToThread( + get_rotating_message_tablename, + prefix=mconf.tablename, + delta=0, + date=tomorrow, + message_read_throughput=mconf.read_throughput, + message_write_throughput=mconf.write_throughput, + boto_resource=boto_resource + ) + self.message_tables.append(next_month) + if today.month == self.current_month: + # No change in month, we're fine. + returnValue(False) + + # Get tables for the new month, and verify they exist before we + # try to switch over + message_table = yield deferToThread( + get_rotating_message_tablename, prefix=mconf.tablename, - delta=0, - date=tomorrow, message_read_throughput=mconf.read_throughput, message_write_throughput=mconf.write_throughput ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) - - if today.month == self.current_month: - # No change in month, we're fine. - returnValue(False) - - # Get tables for the new month, and verify they exist before we try to - # switch over - message_table = yield deferToThread( - get_rotating_message_table, - prefix=mconf.tablename, - message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput - ) - # Both tables found, safe to switch-over - self.current_month = today.month - self.current_msg_month = message_table.table_name - self.message_tables[self.current_msg_month] = Message( - message_table, self.metrics) + # Both tables found, safe to switch-over + self.current_month = today.month + self.current_msg_month = message_table + finally: + self.resources.release(boto_resource) returnValue(True) diff --git a/autopush/diagnostic_cli.py b/autopush/diagnostic_cli.py index b44d5594..3c945cf7 100644 --- a/autopush/diagnostic_cli.py +++ b/autopush/diagnostic_cli.py @@ -7,7 +7,7 @@ from twisted.logger import Logger from autopush.config import AutopushConfig -from autopush.db import DatabaseManager +from autopush.db import DatabaseManager, Message from autopush.main import AutopushMultiService from autopush.main_argparse import add_shared_args @@ -18,11 +18,12 @@ class EndpointDiagnosticCLI(object): log = Logger() - def __init__(self, sysargs, use_files=True): + def __init__(self, sysargs, resource_pool, use_files=True): ns = self._load_args(sysargs, use_files) self._conf = conf = AutopushConfig.from_argparse(ns) conf.statsd_host = None - self.db = DatabaseManager.from_config(conf) + self.db = DatabaseManager.from_config(conf, + resource_pool=resource_pool) self.db.setup(conf.preflight_uaid) self._endpoint = ns.endpoint self._pp = pprint.PrettyPrinter(indent=4) @@ -69,11 +70,15 @@ def run(self): print("\n") mess_table = rec["current_month"] - chans = self.db.message_tables[mess_table].all_channels(uaid) + chans = Message(mess_table, + resource_pool=self.db.resources).all_channels(uaid) print("Channels in message table:") self._pp.pprint(chans) -def run_endpoint_diagnostic_cli(sysargs=None, use_files=True): - cli = EndpointDiagnosticCLI(sysargs, use_files) +def run_endpoint_diagnostic_cli(sysargs=None, use_files=True, + resource_pool=None): + cli = EndpointDiagnosticCLI(sysargs, + resource_pool=resource_pool, + use_files=use_files) return cli.run() diff --git a/autopush/main.py b/autopush/main.py index 5e24f924..a255c248 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -20,6 +20,7 @@ Sequence, ) +from autopush import constants from autopush.http import ( InternalRouterHTTPFactory, EndpointHTTPFactory, @@ -29,7 +30,7 @@ import autopush.utils as utils import autopush.logging as logging from autopush.config import AutopushConfig -from autopush.db import DatabaseManager +from autopush.db import DatabaseManager, BotoResources # noqa from autopush.exceptions import InvalidConfig from autopush.haproxy import HAProxyServerEndpoint from autopush.logging import PushLogger @@ -61,13 +62,12 @@ class AutopushMultiService(MultiService): config_files = None # type: Sequence[str] logger_name = None # type: str - THREAD_POOL_SIZE = 50 - - def __init__(self, conf): - # type: (AutopushConfig) -> None + def __init__(self, conf, resource_pool=None): + # type: (AutopushConfig, BotoResources) -> None super(AutopushMultiService, self).__init__() self.conf = conf - self.db = DatabaseManager.from_config(conf) + self.db = DatabaseManager.from_config(conf, + resource_pool=resource_pool) self.agent = agent_from_config(conf) @staticmethod @@ -103,7 +103,7 @@ def add_memusage(self): def run(self): """Start the services and run the reactor""" - reactor.suggestThreadPoolSize(self.THREAD_POOL_SIZE) + reactor.suggestThreadPoolSize(constants.THREAD_POOL_SIZE) self.startService() reactor.run() @@ -115,8 +115,8 @@ def stopService(self): undo_monkey_patch_ssl_wrap_socket() @classmethod - def _from_argparse(cls, ns, **kwargs): - # type: (Namespace, **Any) -> AutopushMultiService + def _from_argparse(cls, ns, resource_pool=None, **kwargs): + # type: (Namespace, BotoResources, **Any) -> AutopushMultiService """Create an instance from argparse/additional kwargs""" # Add some entropy to prevent potential conflicts. postfix = os.urandom(4).encode('hex').ljust(8, '0') @@ -126,11 +126,11 @@ def _from_argparse(cls, ns, **kwargs): preflight_uaid="deadbeef00000000deadbeef" + postfix, **kwargs ) - return cls(conf) + return cls(conf, resource_pool=resource_pool) @classmethod - def main(cls, args=None, use_files=True): - # type: (Sequence[str], bool) -> Any + def main(cls, args=None, use_files=True, resource_pool=None): + # type: (Sequence[str], bool, BotoResources) -> Any """Entry point to autopush's main command line scripts. aka autopush/autoendpoint. @@ -148,7 +148,8 @@ def main(cls, args=None, use_files=True): firehose_delivery_stream=ns.firehose_stream_name ) try: - app = cls.from_argparse(ns) + cls.argparse = cls.from_argparse(ns, resource_pool=resource_pool) + app = cls.argparse except InvalidConfig as e: log.critical(str(e)) return 1 @@ -172,9 +173,10 @@ class EndpointApplication(AutopushMultiService): endpoint_factory = EndpointHTTPFactory - def __init__(self, conf): - # type: (AutopushConfig) -> None - super(EndpointApplication, self).__init__(conf) + def __init__(self, conf, resource_pool=None): + # type: (AutopushConfig, BotoResources) -> None + super(EndpointApplication, self).__init__(conf, + resource_pool=resource_pool) self.routers = routers_from_config(conf, self.db, self.agent) def setup(self, rotate_tables=True): @@ -209,8 +211,8 @@ def add_endpoint(self): self.addService(StreamServerEndpointService(ep, factory)) @classmethod - def from_argparse(cls, ns): - # type: (Namespace) -> AutopushMultiService + def from_argparse(cls, ns, resource_pool=None): + # type: (Namespace, BotoResources) -> AutopushMultiService return super(EndpointApplication, cls)._from_argparse( ns, port=ns.port, @@ -220,6 +222,8 @@ def from_argparse(cls, ns): cors=not ns.no_cors, bear_hash_key=ns.auth_key, proxy_protocol_port=ns.proxy_protocol_port, + aws_ddb_endpoint=ns.aws_ddb_endpoint, + resource_pool=resource_pool ) @@ -240,9 +244,12 @@ class ConnectionApplication(AutopushMultiService): websocket_factory = PushServerFactory websocket_site_factory = ConnectionWSSite - def __init__(self, conf): - # type: (AutopushConfig) -> None - super(ConnectionApplication, self).__init__(conf) + def __init__(self, conf, resource_pool=None): + # type: (AutopushConfig, BotoResources) -> None + super(ConnectionApplication, self).__init__( + conf, + resource_pool=resource_pool + ) self.clients = {} # type: Dict[str, PushServerProtocol] def setup(self, rotate_tables=True): @@ -276,8 +283,8 @@ def add_websocket(self): self.add_maybe_ssl(conf.port, site_factory, site_factory.ssl_cf()) @classmethod - def from_argparse(cls, ns): - # type: (Namespace) -> AutopushMultiService + def from_argparse(cls, ns, resource_pool=None): + # type: (Namespace, BotoResources) -> AutopushMultiService return super(ConnectionApplication, cls)._from_argparse( ns, port=ns.port, @@ -298,6 +305,8 @@ def from_argparse(cls, ns): auto_ping_timeout=ns.auto_ping_timeout, max_connections=ns.max_connections, close_handshake_timeout=ns.close_handshake_timeout, + aws_ddb_endpoint=ns.aws_ddb_endpoint, + resource_pool=resource_pool ) @@ -340,8 +349,8 @@ def stopService(self): yield super(RustConnectionApplication, self).stopService() @classmethod - def from_argparse(cls, ns): - # type: (Namespace) -> AutopushMultiService + def from_argparse(cls, ns, resource_pool=None): + # type: (Namespace, BotoResources) -> AutopushMultiService return super(RustConnectionApplication, cls)._from_argparse( ns, port=ns.port, @@ -363,11 +372,13 @@ def from_argparse(cls, ns): auto_ping_timeout=ns.auto_ping_timeout, max_connections=ns.max_connections, close_handshake_timeout=ns.close_handshake_timeout, + aws_ddb_endpoint=ns.aws_ddb_endpoint, + resource_pool=resource_pool ) @classmethod - def main(cls, args=None, use_files=True): - # type: (Sequence[str], bool) -> Any + def main(cls, args=None, use_files=True, resource_pool=None): + # type: (Sequence[str], bool, BotoResources) -> Any """Entry point to autopush's main command line scripts. aka autopush/autoendpoint. @@ -385,7 +396,7 @@ def main(cls, args=None, use_files=True): firehose_delivery_stream=ns.firehose_stream_name ) try: - app = cls.from_argparse(ns) + app = cls.from_argparse(ns, resource_pool=resource_pool) except InvalidConfig as e: log.critical(str(e)) return 1 diff --git a/autopush/main_argparse.py b/autopush/main_argparse.py index 0813a6c2..1ad06da9 100644 --- a/autopush/main_argparse.py +++ b/autopush/main_argparse.py @@ -101,6 +101,10 @@ def add_shared_args(parser): help="Don't cache ssl.wrap_socket's SSLContexts", action="store_true", default=False, env_var="_NO_SSLCONTEXT_CACHE") + parser.add_argument('--aws_ddb_endpoint', + help="AWS DynamoDB endpoint override", + type=str, default=None, + env_var="AWS_LOCAL_DYNAMODB") # No ENV because this is for humans _add_external_router_args(parser) _obsolete_args(parser) diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 177c9a08..93602140 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -231,7 +231,7 @@ def _save_notification(self, uaid_data, notification): "Location": location}, logged_status=204) return deferToThread( - self.db.message_tables[month_table].store_message, + self.db.message_table(month_table).store_message, notification=notification, ) diff --git a/autopush/scripts/drop_user.py b/autopush/scripts/drop_user.py index 2405e176..8a9d9610 100644 --- a/autopush/scripts/drop_user.py +++ b/autopush/scripts/drop_user.py @@ -5,6 +5,7 @@ from autopush.db import ( get_router_table, Router, + BotoResources, ) from autopush.metrics import SinkMetrics @@ -18,20 +19,24 @@ help="Seconds to pause between batches.") def drop_users(router_table_name, months_ago, batch_size, pause_time): router_table = get_router_table(router_table_name) - router = Router(router_table, SinkMetrics()) - + resources = BotoResources() + router = Router(router_table, SinkMetrics(), resource_pool=resources) + ddb_resource = resources.fetch() click.echo("Deleting users with a last_connect %s months ago." % months_ago) count = 0 - for deletes in router.drop_old_users(months_ago): - click.echo("") - count += deletes - if count >= batch_size: - click.echo("Deleted %s user records, pausing for %s seconds." - % pause_time) - time.sleep(pause_time) - count = 0 + try: + for deletes in router.drop_old_users(months_ago, ddb_resource): + click.echo("") + count += deletes + if count >= batch_size: + click.echo("Deleted %s user records, pausing for %s seconds." + % pause_time) + time.sleep(pause_time) + count = 0 + finally: + resources.release(ddb_resource) click.echo("Finished old user purge.") diff --git a/autopush/tests/__init__.py b/autopush/tests/__init__.py index adc0ead4..f5457342 100644 --- a/autopush/tests/__init__.py +++ b/autopush/tests/__init__.py @@ -4,12 +4,9 @@ import subprocess import boto -import botocore -import boto3 import psutil -import autopush.db -from autopush.db import create_rotating_message_table +from autopush.db import create_rotating_message_table, BotoResources here_dir = os.path.abspath(os.path.dirname(__file__)) root_dir = os.path.dirname(os.path.dirname(here_dir)) @@ -17,35 +14,34 @@ ddb_lib_dir = os.path.join(ddb_dir, "DynamoDBLocal_lib") ddb_jar = os.path.join(ddb_dir, "DynamoDBLocal.jar") ddb_process = None +boto_resources = None def setUp(): for name in ('boto', 'boto3', 'botocore'): logging.getLogger(name).setLevel(logging.CRITICAL) - global ddb_process + global ddb_process, boto_resources cmd = " ".join([ "java", "-Djava.library.path=%s" % ddb_lib_dir, "-jar", ddb_jar, "-sharedDb", "-inMemory" ]) - conf = botocore.config.Config( - region_name=os.getenv('AWS_REGION_NAME', 'us-east-1') - ) ddb_process = subprocess.Popen(cmd, shell=True, env=os.environ) - autopush.db.g_dynamodb = boto3.resource( - 'dynamodb', - config=conf, - endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB", "http://127.0.0.1:8000"), + if os.getenv("AWS_LOCAL_DYNAMODB") is None: + os.environ["AWS_LOCAL_DYNAMODB"] = "http://127.0.0.1:8000" + ddb_session_args = dict( + endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB"), aws_access_key_id="BogusKey", aws_secret_access_key="BogusKey", ) - - autopush.db.g_client = autopush.db.g_dynamodb.meta.client - + boto_resources = BotoResources(conf=ddb_session_args) # Setup the necessary message tables message_table = os.environ.get("MESSAGE_TABLE", "message_int_test") - - create_rotating_message_table(prefix=message_table, delta=-1) - create_rotating_message_table(prefix=message_table) + resource = boto_resources.fetch() + create_rotating_message_table(prefix=message_table, delta=-1, + boto_resource=resource) + create_rotating_message_table(prefix=message_table, + boto_resource=resource) + boto_resources.release(resource) def tearDown(): diff --git a/autopush/tests/support.py b/autopush/tests/support.py index 1d4e4bfd..8265120d 100644 --- a/autopush/tests/support.py +++ b/autopush/tests/support.py @@ -8,6 +8,7 @@ Router, ) from autopush.metrics import SinkMetrics +import autopush.tests @implementer(ILogObserver) @@ -44,5 +45,6 @@ def test_db(metrics=None): router_conf=DDBTableConfig(tablename='router'), router=Mock(spec=Router), message_conf=DDBTableConfig(tablename='message'), - metrics=SinkMetrics() if metrics is None else metrics + metrics=SinkMetrics() if metrics is None else metrics, + resources=autopush.tests.boto_resources, ) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index 920565dc..5f56d1cd 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -11,21 +11,26 @@ import pytest from autopush.db import ( - get_rotating_message_table, + get_rotating_message_tablename, get_router_table, create_router_table, preflight_check, table_exists, Message, Router, + BotoResources, generate_last_connect, make_rotating_tablename, _drop_table, - _make_table) + _make_table, + MAX_DDB_SESSIONS, + ) from autopush.exceptions import AutopushException from autopush.metrics import SinkMetrics from autopush.utils import WebPushNotification +# nose fails to import sessions correctly. +import autopush.tests dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) @@ -44,17 +49,43 @@ def make_webpush_notification(uaid, chid, ttl=100): class DbUtilsTest(unittest.TestCase): def test_make_table(self): + fake_resource = Mock() fake_func = Mock() fake_table = "DoesNotExist_{}".format(uuid.uuid4()) - _make_table(fake_func, fake_table, 5, 10) - assert fake_func.call_args[0] == (fake_table, 5, 10) + _make_table(fake_func, fake_table, 5, 10, boto_resource=fake_resource) + assert fake_func.call_args[0] == (fake_table, 5, 10, fake_resource) + + +class SessionsTest(unittest.TestCase): + def test_resource_pool(self): + + testpool = BotoResources( + conf={'endpoint_url': 'http://localhost:8000'} + ) + hold = [] + with pytest.raises(ClientError): + for i in range(0, MAX_DDB_SESSIONS + 1): + hold.append(testpool.fetch()) + + assert len(hold) == MAX_DDB_SESSIONS class DbCheckTestCase(unittest.TestCase): + def setUp(cls): + cls.resource = autopush.tests.boto_resources.fetch() + + def tearDown(cls): + autopush.tests.boto_resources.release(cls.resource) + def test_preflight_check_fail(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + router = Router(get_router_table(boto_resource=self.resource), + SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + message = Message(get_rotating_message_tablename( + boto_resource=self.resource), + SinkMetrics(), + resource_pool=autopush.tests.boto_resources) def raise_exc(*args, **kwargs): # pragma: no cover raise Exception("Oops") @@ -63,14 +94,18 @@ def raise_exc(*args, **kwargs): # pragma: no cover router.clear_node.side_effect = raise_exc with pytest.raises(Exception): - preflight_check(message, router) + preflight_check(message, router, self.resource) def test_preflight_check(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + message = Message(get_rotating_message_tablename( + boto_resource=self.resource), + SinkMetrics(), + resource_pool=autopush.tests.boto_resources) pf_uaid = "deadbeef00000000deadbeef01010101" - preflight_check(message, router, pf_uaid) + preflight_check(message, router, pf_uaid, self.resource) # now check that the database reports no entries. _, notifs = message.fetch_messages(uuid.UUID(pf_uaid)) assert len(notifs) == 0 @@ -78,14 +113,18 @@ def test_preflight_check(self): router.get_uaid(pf_uaid) def test_preflight_check_wait(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + message = Message(get_rotating_message_tablename( + boto_resource=self.resource), + SinkMetrics(), + resource_pool=autopush.tests.boto_resources) values = ["PENDING", "ACTIVE"] message.table_status = Mock(side_effect=values) pf_uaid = "deadbeef00000000deadbeef01010101" - preflight_check(message, router, pf_uaid) + preflight_check(message, router, pf_uaid, self.resource) # now check that the database reports no entries. _, notifs = message.fetch_messages(uuid.UUID(pf_uaid)) assert len(notifs) == 0 @@ -126,21 +165,25 @@ def test_normalize_id(self): class MessageTestCase(unittest.TestCase): def setUp(self): - table = get_rotating_message_table() + self.resource = autopush.tests.boto_resources.fetch() + table = get_rotating_message_tablename(boto_resource=self.resource) self.real_table = table self.uaid = str(uuid.uuid4()) def tearDown(self): + autopush.tests.boto_resources.release(self.resource) pass def test_register(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + message.register_channel(self.uaid, chid) + lm = self.resource.Table(m) # Verify it's in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -157,12 +200,15 @@ def test_register(self): def test_unregister(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) message.register_channel(self.uaid, chid) # Verify its in the db - response = m.query( + lm = self.resource.Table(m) + # Verify it's in the db + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -182,7 +228,7 @@ def test_unregister(self): message.unregister_channel(self.uaid, chid) # Verify its not in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -200,18 +246,21 @@ def test_unregister(self): assert results[0].get("chids") is None # Test for the very unlikely case that there's no 'chid' - m.update_item = Mock(return_value={ + mtable = Mock() + mtable.update_item = Mock(return_value={ 'Attributes': {'uaid': self.uaid}, 'ResponseMetaData': {} }) + message.table = Mock(return_value=mtable) r = message.unregister_channel(self.uaid, dummy_chid) assert r is False def test_all_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -225,24 +274,27 @@ def test_all_channels(self): assert chid in chans def test_all_channels_fail(self): - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.table.get_item = Mock() - message.table.get_item.return_value = { + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + + mtable = Mock() + mtable.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } - + message.table = Mock(return_value=mtable) res = message.all_channels(self.uaid) assert res == (False, set([])) def test_save_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -253,16 +305,18 @@ def test_save_channels(self): assert chans == new_chans def test_all_channels_no_uaid(self): - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) exists, chans = message.all_channels(dummy_uaid) assert chans == set([]) def test_message_storage(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -283,8 +337,9 @@ def test_message_storage_overwrite(self): notif2 = make_webpush_notification(self.uaid, chid) notif3 = make_webpush_notification(self.uaid, chid2) notif2.message_id = notif1.message_id - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -292,20 +347,23 @@ def test_message_storage_overwrite(self): message.store_message(notif2) message.store_message(notif3) - all_messages = list(message.fetch_messages(uuid.UUID(self.uaid))) + all_messages = list(message.fetch_messages( + uuid.UUID(self.uaid))) assert len(all_messages) == 2 def test_message_delete_fail_condition(self): notif = make_webpush_notification(dummy_uaid, dummy_chid) notif.message_id = notif.update_id = dummy_uaid - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_resource=self.resource) + message = Message(m, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) def raise_condition(*args, **kwargs): raise ClientError({}, 'delete_item') - message.table = Mock() - message.table.delete_item.side_effect = raise_condition + m_de = Mock() + m_de.delete_item = Mock(side_effect=raise_condition) + message.table = Mock(return_value=m_de) result = message.delete_message(notif) assert result is False @@ -314,22 +372,21 @@ def test_message_rotate_table_with_date(self): future = (datetime.today() + timedelta(days=32)).date() tbl_name = make_rotating_tablename(prefix, date=future) - m = get_rotating_message_table(prefix=prefix, date=future) - assert m.table_name == tbl_name + m = get_rotating_message_tablename(prefix=prefix, date=future, + boto_resource=self.resource) + assert m == tbl_name # Clean up the temp table. - _drop_table(tbl_name) + _drop_table(tbl_name, boto_resource=self.resource) class RouterTestCase(unittest.TestCase): @classmethod - def setup_class(self): - table = get_router_table() - self.real_table = table - self.real_connection = table.meta.client + def setUpClass(cls): + cls.boto_resource = autopush.tests.boto_resources.fetch() @classmethod - def teardown_class(self): - self.real_table.meta.client = self.real_connection + def tearDownClass(cls): + autopush.tests.boto_resources.release(cls.boto_resource) def _create_minimal_record(self): data = { @@ -342,10 +399,10 @@ def _create_minimal_record(self): def test_drop_old_users(self): # First create a bunch of users - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) # Purge any existing users from previous runs. - router.drop_old_users(0) + router.drop_old_users(months_ago=0) for _ in range(0, 53): router.register_user(self._create_minimal_record()) @@ -354,57 +411,61 @@ def test_drop_old_users(self): def test_custom_tablename(self): db_name = "router_%s" % uuid.uuid4() - assert not table_exists(db_name) - create_router_table(db_name) - assert table_exists(db_name) + assert not table_exists(db_name, boto_resource=self.boto_resource) + create_router_table(db_name, boto_resource=self.boto_resource) + assert table_exists(db_name, boto_resource=self.boto_resource) # Clean up the temp table. - _drop_table(db_name) + _drop_table(db_name, boto_resource=self.boto_resource) def test_provisioning(self): db_name = "router_%s" % uuid.uuid4() - r = create_router_table(db_name, 3, 17) + r = create_router_table(db_name, 3, 17, + boto_resource=self.boto_resource) assert r.provisioned_throughput.get('ReadCapacityUnits') == 3 assert r.provisioned_throughput.get('WriteCapacityUnits') == 17 def test_no_uaid_found(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) with pytest.raises(ItemNotFound): router.get_uaid(uaid) def test_uaid_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) router.table = Mock() def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.get_item.side_effect = raise_condition + mm = Mock() + mm.get_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: router.get_uaid(uaid="asdf") assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_register_user_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + mm = Mock() + mm.client = Mock() + + router.table = Mock(return_value=mm) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + mm.update_item = Mock(side_effect=raise_condition) with pytest.raises(ClientError) as ex: router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, @@ -413,35 +474,36 @@ def raise_condition(*args, **kwargs): "ProvisionedThroughputExceededException") def test_register_user_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + router.table().meta.client = Mock() def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - - router.table.update_item = Mock(side_effect=raise_error) + mm = Mock() + mm.update_item = Mock(side_effect=raise_error) + router.table = Mock(return_value=mm) res = router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, router_type="webpush")) assert res == (False, {}) def test_clear_node_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + mm = Mock() + mm.put_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", @@ -451,32 +513,33 @@ def raise_condition(*args, **kwargs): "ProvisionedThroughputExceededException") def test_clear_node_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_put_item' ) - router.table.put_item = Mock(side_effect=raise_error) + router.table().put_item = Mock( + side_effect=raise_error) res = router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", node_id="asdf", router_type="webpush")) + assert res is False def test_incomplete_uaid(self): # Older records may be incomplete. We can't inject them using normal # methods. uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 200 }, @@ -484,6 +547,13 @@ def test_incomplete_uaid(self): "uaid": uuid.uuid4().hex } } + mm.delete_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 200 + }, + } + router.table = Mock(return_value=mm) + router.drop_user = Mock() try: router.register_user(dict(uaid=uaid)) except AutopushException: @@ -494,24 +564,28 @@ def test_incomplete_uaid(self): def test_failed_uaid(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } + router.table = Mock(return_value=mm) + router.drop_user = Mock() with pytest.raises(ItemNotFound): router.get_uaid(uaid) def test_save_new(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) # Sadly, moto currently does not return an empty value like boto # when not updating data. - router.table.update_item = Mock(return_value={}) + mock_update = Mock() + mock_update.update_item = Mock(return_value={}) + router.table = Mock(return_value=mock_update) result = router.register_user(dict(uaid=dummy_uaid, node_id="me", router_type="webpush", @@ -519,25 +593,26 @@ def test_save_new(self): assert result[0] is True def test_save_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + mock_update = Mock() + mock_update.update_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mock_update) router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, router_type="webpush") result = router.register_user(router_data) assert result == (False, {}) def test_node_clear(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) # Register a node user router.register_user(dict(uaid=dummy_uaid, node_id="asdf", @@ -559,8 +634,8 @@ def test_node_clear(self): assert user["router_type"] == "webpush" def test_node_clear_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) def raise_condition(*args, **kwargs): raise ClientError( @@ -568,15 +643,17 @@ def raise_condition(*args, **kwargs): 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + mock_put = Mock() + mock_put.put_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mock_put) data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) result = router.clear_node(data) assert result is False def test_drop_user(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + resource_pool=autopush.tests.boto_resources) # Register a node user router.register_user(dict(uaid=uaid, node_id="asdf", router_type="webpush", diff --git a/autopush/tests/test_diagnostic_cli.py b/autopush/tests/test_diagnostic_cli.py index 9094f01e..4958dd6b 100644 --- a/autopush/tests/test_diagnostic_cli.py +++ b/autopush/tests/test_diagnostic_cli.py @@ -2,6 +2,8 @@ from mock import Mock, patch +import autopush.tests + class FakeDict(dict): pass @@ -10,14 +12,18 @@ class FakeDict(dict): class DiagnosticCLITestCase(unittest.TestCase): def _makeFUT(self, *args, **kwargs): from autopush.diagnostic_cli import EndpointDiagnosticCLI - return EndpointDiagnosticCLI(*args, use_files=False, **kwargs) + return EndpointDiagnosticCLI( + *args, + resource_pool=autopush.tests.boto_resources, + use_files=False, + **kwargs) def test_basic_load(self): cli = self._makeFUT([ "--router_tablename=fred", "http://someendpoint", ]) - assert cli.db.router.table.table_name == "fred" + assert cli.db.router.table().table_name == "fred" def test_bad_endpoint(self): cli = self._makeFUT([ @@ -27,9 +33,10 @@ def test_bad_endpoint(self): returncode = cli.run() assert returncode not in (None, 0) + @patch("autopush.diagnostic_cli.Message") @patch("autopush.diagnostic_cli.AutopushConfig") @patch("autopush.diagnostic_cli.DatabaseManager.from_config") - def test_successfull_lookup(self, mock_db_cstr, mock_conf_class): + def test_successfull_lookup(self, mock_db_cstr, mock_conf_class, mock_msg): from autopush.diagnostic_cli import run_endpoint_diagnostic_cli mock_conf_class.return_value = mock_conf = Mock() mock_conf.parse_endpoint.return_value = dict( @@ -39,11 +46,14 @@ def test_successfull_lookup(self, mock_db_cstr, mock_conf_class): mock_db.router.get_uaid.return_value = mock_item = FakeDict() mock_item._data = {} mock_item["current_month"] = "201608120002" - mock_message_table = Mock() - mock_db.message_tables = {"201608120002": mock_message_table} - - run_endpoint_diagnostic_cli([ - "--router_tablename=fred", - "http://something/wpush/v1/legit_endpoint", - ], use_files=False) - mock_message_table.all_channels.assert_called() + mock_db.message_tables = ["201608120002"] + mock_msg.return_value = mock_message = Mock() + + run_endpoint_diagnostic_cli( + sysargs=[ + "--router_tablename=fred", + "http://something/wpush/v1/legit_endpoint", + ], + use_files=False, + resource_pool=autopush.tests.boto_resources) + mock_message.all_channels.assert_called() diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index af7426aa..70757f70 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -14,6 +14,7 @@ Message, ItemNotFound, has_connected_this_month, + Router ) from autopush.exceptions import RouterException from autopush.http import EndpointHTTPFactory @@ -51,7 +52,7 @@ def setUp(self): crypto_key='AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=', ) db = test_db() - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) self.fernet_mock = conf.fernet = Mock(spec=Fernet) app = EndpointHTTPFactory.for_handler(MessageHandler, conf, db=db) @@ -140,6 +141,7 @@ def setUp(self): self.fernet_mock = conf.fernet = Mock(spec=Fernet) self.db = db = test_db() + db.router = Mock(spec=Router) db.router.register_user.return_value = (True, {}, {}) db.router.get_uaid.return_value = { "router_type": "test", diff --git a/autopush/tests/test_health.py b/autopush/tests/test_health.py index 3e96896f..0a1cfc03 100644 --- a/autopush/tests/test_health.py +++ b/autopush/tests/test_health.py @@ -1,15 +1,13 @@ import json import twisted.internet.base -from boto.dynamodb2.exceptions import InternalServerError from mock import Mock from twisted.internet.defer import inlineCallbacks from twisted.logger import globalLogPublisher from twisted.trial import unittest -import autopush.db from autopush import __version__ -from autopush.config import AutopushConfig +from autopush.config import AutopushConfig, DDBTableConfig from autopush.db import DatabaseManager from autopush.exceptions import MissingTableException from autopush.http import EndpointHTTPFactory @@ -17,6 +15,7 @@ from autopush.tests.client import Client from autopush.tests.support import TestingLogObserver from autopush.web.health import HealthHandler, StatusHandler +import autopush.tests class HealthTestCase(unittest.TestCase): @@ -27,10 +26,13 @@ def setUp(self): conf = AutopushConfig( hostname="localhost", statsd_host=None, + router_table=DDBTableConfig(tablename="router_test"), + message_table=DDBTableConfig(tablename="message_int_test"), ) - db = DatabaseManager.from_config(conf) - db.client = autopush.db.g_client + db = DatabaseManager.from_config( + conf, + resource_pool=autopush.tests.boto_resources) db.setup_tables() # ignore logging @@ -50,82 +52,23 @@ def test_healthy(self): "version": __version__, "clients": 0, "storage": {"status": "OK"}, - "router": {"status": "OK"} + "router_test": {"status": "OK"} }) - @inlineCallbacks - def test_aws_error(self): - - def raise_error(*args, **kwargs): - raise InternalServerError(None, None) - - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = Mock(side_effect=raise_error) - - yield self._assert_reply({ - "status": "NOT OK", - "version": __version__, - "clients": 0, - "storage": { - "status": "NOT OK", - "error": "Server error" - }, - "router": { - "status": "NOT OK", - "error": "Server error" - } - }, InternalServerError) - - self.client.app.db.client = safe - @inlineCallbacks def test_nonexistent_table(self): - no_tables = Mock(return_value={"TableNames": []}) - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = no_tables + self.client.app.db.message.table().delete() yield self._assert_reply({ "status": "NOT OK", "version": __version__, "clients": 0, + "router_test": {"status": "OK"}, "storage": { "status": "NOT OK", "error": "Nonexistent table" - }, - "router": { - "status": "NOT OK", - "error": "Nonexistent table" } }, MissingTableException) - self.client.app.db.client = safe - - @inlineCallbacks - def test_internal_error(self): - def raise_error(*args, **kwargs): - raise Exception("synergies not aligned") - - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = Mock( - side_effect=raise_error - ) - - yield self._assert_reply({ - "status": "NOT OK", - "version": __version__, - "clients": 0, - "storage": { - "status": "NOT OK", - "error": "Internal error" - }, - "router": { - "status": "NOT OK", - "error": "Internal error" - } - }, Exception) - self.client.app.db.client = safe @inlineCallbacks def _assert_reply(self, reply, exception=None): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index b819be2e..aae8e3dd 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -36,12 +36,14 @@ from autopush.config import AutopushConfig from autopush.db import ( get_month, - has_connected_this_month + has_connected_this_month, + Message, ) from autopush.logging import begin_or_register from autopush.main import ConnectionApplication, EndpointApplication from autopush.utils import base64url_encode, normalize_id from autopush.metrics import SinkMetrics, DatadogMetrics +import autopush.tests from autopush.tests.support import TestingLogObserver from autopush.websocket import PushServerFactory @@ -130,7 +132,8 @@ def register(self, chid=None, key=None): key=key)) log.debug("Send: %s", msg) self.ws.send(msg) - result = json.loads(self.ws.recv()) + rcv = self.ws.recv() + result = json.loads(rcv) log.debug("Recv: %s", result) assert result["status"] == 200 assert result["channelID"] == chid @@ -319,15 +322,19 @@ def setUp(self): ) # Endpoint HTTP router - self.ep = ep = EndpointApplication(ep_conf) - ep.db.client = db.g_client + self.ep = ep = EndpointApplication( + conf=ep_conf, + resource_pool=autopush.tests.boto_resources + ) ep.setup(rotate_tables=False) ep.startService() self.addCleanup(ep.stopService) # Websocket server - self.conn = conn = ConnectionApplication(conn_conf) - conn.db.client = db.g_client + self.conn = conn = ConnectionApplication( + conf=conn_conf, + resource_pool=autopush.tests.boto_resources + ) conn.setup(rotate_tables=False) conn.startService() self.addCleanup(conn.stopService) @@ -472,10 +479,10 @@ def test_legacy_simplepush_record(self): return_value={'router_type': 'simplepush', 'uaid': uaid, 'current_month': self.ep.db.current_msg_month}) - self.ep.db.message_tables[ - self.ep.db.current_msg_month].all_channels = Mock( - return_value=(True, client.channels)) + safe = db.Message.all_channels + db.Message.all_channels = Mock(return_value=(True, client.channels)) yield client.send_notification() + db.Message.all_channels = safe yield self.shut_down(client) @patch("autopush.metrics.datadog") @@ -537,12 +544,13 @@ def test_webpush_data_save_fail(self): yield client.hello() yield client.register(chid=chan) yield client.disconnect() - self.ep.db.message_tables[ - self.ep.db.current_msg_month].store_message = Mock( + safe = db.Message.store_message + db.Message.store_message = Mock( return_value=False) yield client.send_notification(channel=chan, data=test["data"], status=201) + db.Message.store_message = safe yield self.shut_down(client) @@ -1157,15 +1165,16 @@ def test_webpush_monthly_rotation(self): # Move the client back one month to the past last_month = make_rotating_tablename( prefix=self.conn.conf.message_table.tablename, delta=-1) - lm_message = self.conn.db.message_tables[last_month] + lm_message = Message(last_month, resource_pool=self.conn.db.resources) yield deferToThread( self.conn.db.router.update_message_month, client.uaid, - last_month + last_month, ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == last_month # Verify last_connect is current, then move that back @@ -1177,33 +1186,34 @@ def test_webpush_monthly_rotation(self): yield deferToThread( self.conn.db.router._update_last_connect, client.uaid, - last_connect, - ) - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + last_connect) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.db.message.all_channels, client.uaid + self.conn.db.message.all_channels, + client.uaid ) assert exists is True assert len(chans) == 1 yield deferToThread( lm_message.save_channels, client.uaid, - chans + chans, ) # Remove the channels entry entirely from this month yield deferToThread( - self.conn.db.message.table.delete_item, + self.conn.db.message.table().delete_item, Key={'uaid': client.uaid, 'chidmessageid': ' '} ) # Verify the channel is gone exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, ) assert exists is False assert len(chans) == 0 @@ -1239,7 +1249,8 @@ def test_webpush_monthly_rotation(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid) if c["current_month"] == self.conn.db.current_msg_month: break else: @@ -1247,7 +1258,8 @@ def test_webpush_monthly_rotation(self): # Verify the month update in the router table c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == self.conn.db.current_msg_month assert server_client.ps.rotate_message_table is False @@ -1261,7 +1273,6 @@ def test_webpush_monthly_rotation(self): ) assert exists is True assert len(chans) == 1 - yield self.shut_down(client) @inlineCallbacks @@ -1273,15 +1284,17 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Move the client back one month to the past last_month = make_rotating_tablename( prefix=self.conn.conf.message_table.tablename, delta=-1) - lm_message = self.conn.db.message_tables[last_month] + lm_message = Message(last_month, + resource_pool=autopush.tests.boto_resources) yield deferToThread( self.conn.db.router.update_message_month, client.uaid, - last_month + last_month, ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == last_month # Verify last_connect is current, then move that back @@ -1290,21 +1303,22 @@ def test_webpush_monthly_rotation_prior_record_exists(self): yield deferToThread( self.conn.db.router._update_last_connect, client.uaid, - int("%s%s020001" % (today.year, str(today.month).zfill(2))) + int("%s%s020001" % (today.year, str(today.month).zfill(2))), ) c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.db.message.all_channels, client.uaid + self.conn.db.message.all_channels, + client.uaid, ) assert exists is True assert len(chans) == 1 yield deferToThread( lm_message.save_channels, client.uaid, - chans + chans, ) # Send in a notification, verify it landed in last months notification @@ -1338,7 +1352,8 @@ def test_webpush_monthly_rotation_prior_record_exists(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid) if c["current_month"] == self.conn.db.current_msg_month: break else: @@ -1359,7 +1374,6 @@ def test_webpush_monthly_rotation_prior_record_exists(self): ) assert exists is True assert len(chans) == 1 - yield self.shut_down(client) @inlineCallbacks @@ -1380,13 +1394,15 @@ def test_webpush_monthly_rotation_no_channels(self): ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid + ) assert c["current_month"] == last_month # Verify there's no channels exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, ) assert exists is False assert len(chans) == 0 @@ -1403,17 +1419,19 @@ def test_webpush_monthly_rotation_no_channels(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid, + ) if c["current_month"] == self.conn.db.current_msg_month: break else: yield deferToThread(time.sleep, 0.2) # Verify the month update in the router table - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid) assert c["current_month"] == self.conn.db.current_msg_month assert server_client.ps.rotate_message_table is False - yield self.shut_down(client) @inlineCallbacks diff --git a/autopush/tests/test_main.py b/autopush/tests/test_main.py index 4b95aa07..9ef167ec 100644 --- a/autopush/tests/test_main.py +++ b/autopush/tests/test_main.py @@ -11,7 +11,7 @@ import autopush.db from autopush.config import AutopushConfig -from autopush.db import DatabaseManager, get_rotating_message_table +from autopush.db import DatabaseManager, get_rotating_message_tablename from autopush.exceptions import InvalidConfig from autopush.http import skip_request_logging from autopush.main import ( @@ -20,6 +20,7 @@ ) from autopush.tests.support import test_db from autopush.utils import resolve_ip +import autopush.tests connection_main = ConnectionApplication.main endpoint_main = EndpointApplication.main @@ -60,19 +61,21 @@ def test_update_rotating_tables(self): from autopush.db import get_month conf = AutopushConfig( hostname="example.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource_pool=autopush.tests.boto_resources) db.create_initial_message_tables() # Erase the tables it has on init, and move current month back one last_month = get_month(-1) db.current_month = last_month.month - db.message_tables = {} + db.message_tables = [] # Create the next month's table, just in case today is the day before # a new month, in which case the lack of keys will cause an error in # update_rotating_tables next_month = get_month(1) - db.message_tables[next_month.month] = None + assert next_month.month not in db.message_tables # Get the deferred back e = Deferred() @@ -115,7 +118,9 @@ def test_update_rotating_tables_month_end(self): conf = AutopushConfig( hostname="example.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource_pool=autopush.tests.boto_resources) db._tomorrow = Mock(return_value=tomorrow) db.create_initial_message_tables() @@ -123,37 +128,35 @@ def test_update_rotating_tables_month_end(self): assert len(db.message_tables) == 3 # Grab next month's table name and remove it - next_month = get_rotating_message_table( + resource = db.resources.fetch() + next_month = get_rotating_message_tablename( conf.message_table.tablename, - delta=1 + delta=1, + boto_resource=resource ) - db.message_tables.pop(next_month.table_name) + db.resources.release(resource) + db.message_tables.pop(db.message_tables.index(next_month)) # Get the deferred back d = db.update_rotating_tables() def check_tables(result): assert len(db.message_tables) == 3 - assert next_month.table_name in db.message_tables + assert next_month in db.message_tables d.addCallback(check_tables) return d def test_update_not_needed(self): - from autopush.db import get_month conf = AutopushConfig( hostname="google.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource_pool=autopush.tests.boto_resources) db.create_initial_message_tables() # Erase the tables it has on init, and move current month back one - db.message_tables = {} - - # Create the next month's table, just in case today is the day before - # a new month, in which case the lack of keys will cause an error in - # update_rotating_tables - next_month = get_month(1) - db.message_tables[next_month.month] = None + db.message_tables = [] # Get the deferred back e = Deferred() @@ -184,7 +187,7 @@ def tearDown(self): mock.stop() def test_basic(self): - connection_main([], False) + connection_main([], False, resource_pool=autopush.tests.boto_resources) def test_ssl(self): connection_main([ @@ -193,12 +196,12 @@ def test_ssl(self): "--ssl_key=keys/server.key", "--router_ssl_cert=keys/server.crt", "--router_ssl_key=keys/server.key", - ], False) + ], False, resource_pool=autopush.tests.boto_resources) def test_memusage(self): connection_main([ "--memusage_port=8083", - ], False) + ], False, resource_pool=autopush.tests.boto_resources) def test_skip_logging(self): # Should skip setting up logging on the handler @@ -258,6 +261,7 @@ class TestArg(AutopushConfig): use_cryptography = False sts_max_age = 1234 _no_sslcontext_cache = False + aws_ddb_endpoint = None def setUp(self): patchers = [ @@ -277,15 +281,18 @@ def tearDown(self): autopush.db.key_hash = "" def test_basic(self): - endpoint_main([ - ], False) + endpoint_main( + [], + False, + resource_pool=autopush.tests.boto_resources + ) def test_ssl(self): endpoint_main([ "--ssl_dh_param=keys/dhparam.pem", "--ssl_cert=keys/server.crt", "--ssl_key=keys/server.key", - ], False) + ], False, resource_pool=autopush.tests.boto_resources) def test_bad_senderidlist(self): returncode = endpoint_main([ @@ -306,18 +313,18 @@ def test_client_certs(self): "--ssl_cert=keys/server.crt", "--ssl_key=keys/server.key", '--client_certs={"foo": ["%s"]}' % cert - ], False) + ], False, resource_pool=autopush.tests.boto_resources) assert not returncode def test_proxy_protocol_port(self): endpoint_main([ "--proxy_protocol_port=8081", - ], False) + ], False, resource_pool=autopush.tests.boto_resources) def test_memusage(self): endpoint_main([ "--memusage_port=8083", - ], False) + ], False, resource_pool=autopush.tests.boto_resources) def test_client_certs_parse(self): conf = AutopushConfig.from_argparse(self.TestArg) @@ -344,7 +351,8 @@ def test_bad_client_certs(self): @patch('hyper.tls', spec=hyper.tls) def test_conf(self, *args): conf = AutopushConfig.from_argparse(self.TestArg) - app = EndpointApplication(conf) + app = EndpointApplication(conf, + resource_pool=autopush.tests.boto_resources) # verify that the hostname is what we said. assert conf.hostname == self.TestArg.hostname assert app.routers["gcm"].router_conf['collapsekey'] == "collapse" @@ -375,7 +383,7 @@ def test_gcm_start(self): endpoint_main([ "--gcm_enabled", """--senderid_list={"123":{"auth":"abcd"}}""", - ], False) + ], False, resource_pool=autopush.tests.boto_resources) @patch("requests.get") def test_aws_ami_id(self, request_mock): diff --git a/autopush/tests/test_router.py b/autopush/tests/test_router.py index 9e721585..cf25aa23 100644 --- a/autopush/tests/test_router.py +++ b/autopush/tests/test_router.py @@ -7,9 +7,8 @@ import requests import ssl -from autopush.utils import WebPushNotification -from mock import Mock, PropertyMock, patch import pytest +from mock import Mock, PropertyMock, patch from twisted.trial import unittest from twisted.internet.error import ConnectionRefusedError from twisted.internet.defer import inlineCallbacks @@ -35,6 +34,7 @@ from autopush.router.interface import RouterResponse, IRouter from autopush.tests import MockAssist from autopush.tests.support import test_db +from autopush.utils import WebPushNotification class RouterInterfaceTestCase(TestCase): @@ -999,7 +999,7 @@ def setUp(self): mock_result.not_registered = dict() mock_result.retry_after = 1000 self.router_mock = db.router - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) self.conf = conf def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): @@ -1009,6 +1009,7 @@ def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): side_effect=MockAssist([202, 200])) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1033,6 +1034,7 @@ def test_route_failure(self): self.agent_mock.request = Mock(side_effect=ConnectionRefusedError) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1070,6 +1072,7 @@ def test_route_to_busy_node_with_ttl_zero(self): side_effect=MockAssist([202, 200])) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1100,9 +1103,8 @@ def throw(): self.agent_mock.request.return_value = response_mock = Mock() response_mock.code = 202 - self.message_mock.store_message.side_effect = MockAssist( - [throw] - ) + self.message_mock.store_message.side_effect = MockAssist([throw]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) @@ -1123,6 +1125,7 @@ def throw(): raise JSONResponseError(500, "Whoops") self.message_mock.store_message.return_value = True + self.db.message_table = Mock(return_value=self.message_mock) self.router_mock.get_uaid.side_effect = MockAssist( [throw] ) @@ -1144,6 +1147,7 @@ def throw(): raise ItemNotFound() self.message_mock.store_message.return_value = True + self.db.message_table = Mock(return_value=self.message_mock) self.router_mock.get_uaid.side_effect = MockAssist( [throw] ) @@ -1160,9 +1164,8 @@ def verify_deliver(status): def test_route_lookup_uaid_no_nodeid(self): self.message_mock.store_message.return_value = True - self.router_mock.get_uaid.return_value = dict( - - ) + self.db.message_table = Mock(return_value=self.message_mock) + self.router_mock.get_uaid.return_value = dict() router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) @@ -1179,6 +1182,7 @@ def test_route_and_clear_failure(self): self.agent_mock.request = Mock(side_effect=ConnectionRefusedError) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data diff --git a/autopush/tests/test_web_base.py b/autopush/tests/test_web_base.py index dc56255d..fd1d2f71 100644 --- a/autopush/tests/test_web_base.py +++ b/autopush/tests/test_web_base.py @@ -195,8 +195,7 @@ def test_response_err(self): def test_overload_err(self): try: - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': { 'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' @@ -208,8 +207,7 @@ def test_overload_err(self): def test_client_err(self): try: - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': { 'Code': 'Flibbertygidgit'}}, 'mock_update_item' diff --git a/autopush/tests/test_web_webpush.py b/autopush/tests/test_web_webpush.py index ca6327f5..204fe0cd 100644 --- a/autopush/tests/test_web_webpush.py +++ b/autopush/tests/test_web_webpush.py @@ -20,6 +20,7 @@ class TestWebpushHandler(unittest.TestCase): def setUp(self): + import autopush from autopush.web.webpush import WebPushHandler self.conf = conf = AutopushConfig( @@ -30,11 +31,13 @@ def setUp(self): self.fernet_mock = conf.fernet = Mock(spec=Fernet) self.db = db = test_db() - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) + self.db.message_table = Mock(return_value=self.message_mock) self.message_mock.all_channels.return_value = (True, [dummy_chid]) app = EndpointHTTPFactory.for_handler(WebPushHandler, conf, db=db) self.wp_router_mock = app.routers["webpush"] = Mock(spec=IRouter) + self.db.router = Mock(spec=autopush.db.Router) self.client = Client(app) def url(self, **kwargs): @@ -48,16 +51,18 @@ def test_router_needs_update(self): public_key="asdfasdf", )) self.fernet_mock.decrypt.return_value = dummy_token - self.db.router.get_uaid.return_value = dict( + self.db.router.get_uaid = Mock(return_value=dict( router_type="webpush", router_data=dict(), uaid=dummy_uaid, current_month=self.db.current_msg_month, - ) + )) + self.db.router.register_user = Mock(return_value=False) self.wp_router_mock.route_notification.return_value = RouterResponse( status_code=503, router_data=dict(token="new_connect"), ) + self.db.message_tables.append(self.db.current_msg_month) resp = yield self.client.post( self.url(api_ver="v1", token=dummy_token), @@ -85,6 +90,7 @@ def test_router_returns_data_without_detail(self): status_code=503, router_data=dict(), ) + self.db.message_tables.append(self.db.current_msg_month) resp = yield self.client.post( self.url(api_ver="v1", token=dummy_token), diff --git a/autopush/tests/test_webpush_server.py b/autopush/tests/test_webpush_server.py index 2805ff37..025d14c8 100644 --- a/autopush/tests/test_webpush_server.py +++ b/autopush/tests/test_webpush_server.py @@ -15,6 +15,7 @@ DatabaseManager, generate_last_connect, make_rotating_tablename, + Message, ) from autopush.metrics import SinkMetrics from autopush.config import AutopushConfig @@ -35,6 +36,7 @@ Unregister, WebPushMessage, ) +import autopush.tests class AutopushCall(object): @@ -170,14 +172,18 @@ def setUp(self): begin_or_register(self.logs) self.addCleanup(globalLogPublisher.removeObserver, self.logs) - self.db = db = DatabaseManager.from_config(self.conf) + self.db = db = DatabaseManager.from_config( + self.conf, + resource_pool=autopush.tests.boto_resources) self.metrics = db.metrics = Mock(spec=SinkMetrics) db.setup_tables() def _store_messages(self, uaid, topic=False, num=5): try: item = self.db.router.get_uaid(uaid.hex) - message_table = self.db.message_tables[item["current_month"]] + message_table = Message( + item["current_month"], + resource_pool=autopush.tests.boto_resources) except ItemNotFound: message_table = self.db.message messages = [WebPushNotificationFactory(uaid=uaid) @@ -441,7 +447,9 @@ def test_migrate_user(self): # Check that it's there item = self.db.router.get_uaid(uaid) - _, channels = self.db.message_tables[last_month].all_channels(uaid) + _, channels = Message( + last_month, + resource_pool=self.db.resources).all_channels(uaid) assert item["current_month"] != self.db.current_msg_month assert item is not None assert len(channels) == 3 @@ -504,16 +512,18 @@ def test_register_bad_chid_nodash(self): self._test_invalid(uuid4().hex) def test_register_over_provisioning(self): + import autopush def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + from botocore.exceptions import ClientError + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - self.db.message.table.update_item = Mock( - side_effect=raise_condition) + mock_table = Mock(spec=autopush.db.Message) + mock_table.register_channel = Mock(side_effect=raise_condition) + self.db.message_table = Mock(return_value=mock_table) self._test_invalid(str(uuid4()), "overloaded", 503) diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 9f973dad..a09f44f2 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -8,6 +8,7 @@ import twisted.internet.base from autobahn.twisted.util import sleep from autobahn.websocket.protocol import ConnectionRequest +from botocore.exceptions import ClientError from mock import Mock, patch import pytest from twisted.internet import reactor @@ -25,7 +26,6 @@ from autopush.db import DatabaseManager from autopush.http import InternalRouterHTTPFactory from autopush.metrics import SinkMetrics -from autopush.tests import MockAssist from autopush.utils import WebPushNotification from autopush.tests.client import Client from autopush.tests.test_db import make_webpush_notification @@ -37,6 +37,7 @@ WebSocketServerProtocol, ) from autopush.utils import base64url_encode, ms_time +import autopush.tests dummy_version = (u'gAAAAABX_pXhN22H-hvscOHsMulKvtC0hKJimrZivbgQPFB3sQAtOPmb' @@ -101,7 +102,10 @@ def setUp(self): statsd_host=None, env="test", ) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + resource_pool=autopush.tests.boto_resources + ) self.metrics = db.metrics = Mock(spec=SinkMetrics) db.setup_tables() @@ -381,7 +385,9 @@ def test_close_with_delivery_cleanup(self): self.proto.ps.direct_updates[chid] = [notif] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -400,7 +406,9 @@ def test_close_with_delivery_cleanup_using_webpush(self): self.proto.ps.direct_updates[dummy_chid_str] = [dummy_notif()] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -421,7 +429,9 @@ def test_close_with_delivery_cleanup_and_get_no_result(self): self.proto.ps.direct_updates[chid] = [notif] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = False self.metrics.reset_mock() @@ -451,7 +461,7 @@ def test_hello_old(self): "current_month": msg_date, } router = self.proto.db.router - router.table.put_item( + router.table().put_item( Item=dict( uaid=orig_uaid, connected_at=ms_time(), @@ -460,7 +470,7 @@ def test_hello_old(self): ) ) - def fake_msg(data): + def fake_msg(data, **kwargs): return (True, msg_data) mock_msg = Mock(wraps=db.Message) @@ -473,12 +483,10 @@ def fake_msg(data): # notifications are irrelevant for this test. self.proto.process_notifications = Mock() # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - for k in mt.keys(): - del(mt[k]) - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', 'message_2016_3' + ] + self.proto.ps.db.message_table = Mock(return_value=mock_msg) with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: patched.today.return_value = target_day @@ -521,7 +529,7 @@ def test_hello_tomorrow(self): "current_month": msg_date, } - def fake_msg(data): + def fake_msg(data, **kwargs): return (True, msg_data) mock_msg = Mock(wraps=db.Message) @@ -530,12 +538,10 @@ def fake_msg(data): mock_msg.all_channels.return_value = (None, []) self.proto.db.router.register_user = fake_msg # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - for k in mt.keys(): - del(mt[k]) - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg + self.proto.db.message_table = Mock(return_value=mock_msg) + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', 'message_2016_3' + ] with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: patched.today.return_value = target_day @@ -559,10 +565,11 @@ def fake_msg(data): def test_hello_tomorrow_provision_error(self): orig_uaid = "deadbeef00000000abad1dea00000000" router = self.proto.db.router + current_month = "message_2016_3" router.register_user(dict( uaid=orig_uaid, connected_at=ms_time(), - current_month="message_2016_3", + current_month=current_month, router_type="webpush", )) @@ -580,36 +587,42 @@ def test_hello_tomorrow_provision_error(self): "current_month": msg_date, } - def fake_msg(data): - return (True, msg_data) - mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = "01;", [] mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) - self.proto.db.router.register_user = fake_msg + self.proto.db.router.register_user = Mock( + return_value=(True, msg_data) + ) # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - mt.clear() - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg - + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', current_month + ] + self.proto.db.message_table = Mock(return_value=mock_msg) patch_range = patch("autopush.websocket.randrange") mock_patch = patch_range.start() mock_patch.return_value = 1 def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - self.proto.db.router.update_message_month = MockAssist([ - raise_condition, - Mock(), - ]) + self.proto.db.register_user = Mock(return_value=(False, {})) + mock_router = Mock(spec=db.Router) + mock_router.register_user = Mock(return_value=(True, msg_data)) + mock_router.update_message_month = Mock(side_effect=raise_condition) + self.proto.db.router = mock_router + self.proto.db.router.get_uaid = Mock(return_value={ + "router_type": "webpush", + "connected_at": int(msg_day.strftime("%s")), + "current_month": current_month, + "last_connect": int(msg_day.strftime("%s")), + "record_version": 1, + }) + self.proto.db.current_msg_month = current_month + self.proto.ps.message_month = current_month with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: @@ -711,7 +724,9 @@ def test_hello_failure(self): self._connect() # Fail out the register_user call router = self.proto.db.router - router.table.update_item = Mock(side_effect=KeyError) + mock_up = Mock() + mock_up.update_item = Mock(side_effect=KeyError) + router.table = Mock(return_value=mock_up) self._send_message(dict(messageType="hello", channelIDs=[], use_webpush=True, stop=1)) @@ -727,15 +742,15 @@ def test_hello_provisioned_during_check(self): # Fail out the register_user call def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) router = self.proto.db.router - router.table.update_item = Mock(side_effect=raise_condition) - + mock_table = Mock() + mock_table.update_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mock_table) self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) msg = yield self.get_response() @@ -911,10 +926,12 @@ def test_register_webpush(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.register_channel = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) yield self.proto.process_register(dict(channelID=chid)) - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -922,7 +939,8 @@ def test_register_webpush_with_key(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = Mock() + msg_mock = Mock(spec=db.Message) + self.proto.db.message_table = Mock(return_value=msg_mock) test_key = "SomeRandomCryptoKeyString" test_sha = sha256(test_key).hexdigest() test_endpoint = ('http://localhost/wpush/v2/' + @@ -942,7 +960,7 @@ def echo(string): ) assert test_endpoint == self.proto.sendJSON.call_args[0][0][ 'pushEndpoint'] - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -1046,11 +1064,12 @@ def test_register_over_provisioning(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = register = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.register_channel = register = Mock() + self.proto.ps.db.message_table = Mock(return_value=msg_mock) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) @@ -1058,7 +1077,7 @@ def raise_condition(*args, **kwargs): register.side_effect = raise_condition yield self.proto.process_register(dict(channelID=chid)) - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert self.send_mock.called args, _ = self.send_mock.call_args msg = json.loads(args[0]) @@ -1305,9 +1324,11 @@ def test_process_notifications(self): self.proto.ps.uaid = uuid.uuid4().hex # Swap out fetch_notifications - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( return_value=(None, []) ) + self.proto.ps.db.message_table = Mock(return_value=msg_mock) self.proto.process_notifications() @@ -1337,13 +1358,12 @@ def throw(*args, **kwargs): twisted.internet.base.DelayedCall.debug = True self._connect() - self.proto.db.message.fetch_messages = Mock( - return_value=(None, []) - ) - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( side_effect=throw) - self.proto.db.message.fetch_timestamp_messages = Mock( + msg_mock.fetch_timestamp_messages = Mock( side_effect=throw) + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.ps.uaid = uuid.uuid4().hex self.proto.process_notifications() @@ -1367,21 +1387,19 @@ def wait(result): def test_process_notifications_provision_err(self): def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) twisted.internet.base.DelayedCall.debug = True self._connect() - self.proto.db.message.fetch_messages = Mock( - return_value=(None, []) - ) - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( side_effect=raise_condition) - self.proto.db.message.fetch_timestamp_messages = Mock( + msg_mock.fetch_timestamp_messages = Mock( side_effect=raise_condition) + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.deferToLater = Mock() self.proto.ps.uaid = uuid.uuid4().hex @@ -1488,7 +1506,9 @@ def test_notif_finished_with_too_many_messages(self): self.proto.ps.uaid = uuid.uuid4().hex self.proto.ps._check_notifications = True self.proto.db.router.drop_user = Mock() - self.proto.ps.message.fetch_messages = Mock() + msg_mock = Mock() + msg_mock.fetch_messages = Mock() + self.proto.ps.db.message_table = Mock(return_value=msg_mock) notif = make_webpush_notification( self.proto.ps.uaid, @@ -1496,7 +1516,7 @@ def test_notif_finished_with_too_many_messages(self): ttl=500 ) self.proto.ps.updates_sent = defaultdict(lambda: []) - self.proto.ps.message.fetch_messages.return_value = ( + msg_mock.fetch_messages.return_value = ( None, [notif, notif, notif] ) diff --git a/autopush/web/health.py b/autopush/web/health.py index d28265b2..1f2be960 100644 --- a/autopush/web/health.py +++ b/autopush/web/health.py @@ -33,14 +33,15 @@ def get(self): ) dl = DeferredList([ - self._check_table(self.db.router.table), - self._check_table(self.db.message.table, "storage") + self._check_table(self.db.router.table()), + self._check_table(self.db.message.table(), "storage"), ]) dl.addBoth(self._finish_response) def _check_table(self, table, name_over=None): """Checks the tables known about in DynamoDB""" - d = deferToThread(table_exists, table.table_name, self.db.client) + d = deferToThread(table_exists, table.table_name, + resource_pool=self.db.resources) d.addCallback(self._check_success, name_over or table.table_name) d.addErrback(self._check_error, name_over or table.table_name) return d @@ -61,7 +62,7 @@ def _check_error(self, failure, name): if failure.check(InternalServerError): cause["error"] = "Server error" elif failure.check(MissingTableException): - cause["error"] = failure.getErrorMessage() + cause["error"] = failure.value.message else: cause["error"] = "Internal error" diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 2d6002fa..954198b8 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -149,7 +149,8 @@ def _validate_webpush(self, d, result): db.router.drop_user(uaid) raise InvalidRequest("No such subscription", status_code=410, errno=106) - exists, chans = db.message_tables[month_table].all_channels(uaid=uaid) + msg = db.message_table(month_table) + exists, chans = msg.all_channels(uaid=uaid) if (not exists or channel_id.lower() not in map(lambda x: normalize_id(x), chans)): diff --git a/autopush/webpush_server.py b/autopush/webpush_server.py index 596cfe6e..58361292 100644 --- a/autopush/webpush_server.py +++ b/autopush/webpush_server.py @@ -25,6 +25,8 @@ has_connected_this_month, hasher, generate_last_connect, + Message, + BotoResources, ) from autopush.config import AutopushConfig # noqa @@ -459,10 +461,11 @@ def process(self, command): def _check_storage(self, command): timestamp = None messages = [] - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + resource_pool=self.db.resources) if command.include_topic: timestamp, messages = message.fetch_messages( - uaid=command.uaid, limit=11, + uaid=command.uaid, limit=11 ) # If we have topic messages, return them immediately @@ -488,7 +491,8 @@ def _check_storage(self, command): class IncrementStorageCommand(ProcessorCommand): def process(self, command): # type: (IncStoragePosition) -> IncStoragePositionResponse - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + resource_pool=self.db.resources) message.update_last_message_read(command.uaid, command.timestamp) return IncStoragePositionResponse() @@ -497,7 +501,8 @@ class DeleteMessageCommand(ProcessorCommand): def process(self, command): # type: (DeleteMessage) -> DeleteMessageResponse notif = command.message.to_WebPushNotification() - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + resource_pool=self.db.resources) message.delete_message(notif) return DeleteMessageResponse() @@ -513,25 +518,30 @@ class MigrateUserCommand(ProcessorCommand): def process(self, command): # type: (MigrateUser) -> MigrateUserResponse # Get the current channels for this month - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + resource_pool=self.db.resources) _, channels = message.all_channels(command.uaid.hex) # Get the current message month cur_month = self.db.current_msg_month if channels: # Save the current channels into this months message table - msg_table = self.db.message_tables[cur_month] - msg_table.save_channels(command.uaid.hex, channels) + msg_table = Message(cur_month, + resource_pool=self.db.resources) + msg_table.save_channels(command.uaid.hex, + channels) # Finally, update the route message month - self.db.router.update_message_month(command.uaid.hex, cur_month) + self.db.router.update_message_month(command.uaid.hex, + cur_month) return MigrateUserResponse(message_month=cur_month) class StoreMessagesUserCommand(ProcessorCommand): def process(self, command): # type: (StoreMessages) -> StoreMessagesResponse - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + resource_pool=self.db.resources) for m in command.messages: if "topic" not in m: m["topic"] = None @@ -585,15 +595,15 @@ def process(self, command): command.channel_id, command.key ) - message = self.db.message_tables[command.message_month] + message = self.db.message_table(command.message_month) try: - message.register_channel(command.uaid.hex, command.channel_id) + message.register_channel(command.uaid.hex, + command.channel_id) except ClientError as ex: if (ex.response['Error']['Code'] == "ProvisionedThroughputExceededException"): return RegisterErrorResponse(error_msg="overloaded", status=503) - self.metrics.increment('ua.command.register') log.info( "Register", @@ -634,12 +644,12 @@ def process(self, if not valid: return UnregisterErrorResponse(error_msg=msg) - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + resource_pool=self.db.resources) # TODO: JSONResponseError not handled (no force_retry) message.unregister_channel(command.uaid.hex, command.channel_id) - # TODO: Clear out any existing tracked messages for this - # channel + # TODO: Clear out any existing tracked messages for this channel self.metrics.increment('ua.command.unregister') # TODO: user/raw_agent? diff --git a/autopush/websocket.py b/autopush/websocket.py index a3de965a..ca68ce50 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -249,7 +249,7 @@ def __attrs_post_init__(self): def message(self): # type: () -> Message """Property to access the currently used message table""" - return self.db.message_tables[self.message_month] + return Message(self.message_month, resource_pool=self.db.resources) @property def user_agent(self): @@ -596,7 +596,8 @@ def cleanUp(self, wasClean, code, reason): def _save_webpush_notif(self, notif): """Save a direct_update webpush style notification""" - return deferToThread(self.ps.message.store_message, + message = self.db.message_table(self.ps.message_month) + return deferToThread(message.store_message, notif).addErrback(self.log_failure) def _lookup_node(self, results): @@ -648,7 +649,7 @@ def returnError(self, messageType, reason, statusCode, close=True, if close: self.sendClose() - def err_overload(self, failure, message_type, disconnect=True): + def error_overload(self, failure, message_type, disconnect=True): """Handle database overloads and errors If ``disconnect`` is False, the an overload error is returned and the @@ -668,7 +669,7 @@ def err_overload(self, failure, message_type, disconnect=True): if disconnect: self.transport.pauseProducing() d = self.deferToLater(self.randrange(4, 9), - self.err_finish_overload, message_type) + self.error_finish_overload, message_type) d.addErrback(self.trap_cancel) else: if (failure.value.response["Error"]["Code"] != @@ -678,7 +679,7 @@ def err_overload(self, failure, message_type, disconnect=True): "status": 503} self.sendJSON(send) - def err_finish_overload(self, message_type): + def error_finish_overload(self, message_type): """Close the connection down and resume consuming input after the random interval from a db overload""" # Resume producing so we can finish the shutdown @@ -716,8 +717,8 @@ def process_hello(self, data): d = self.deferToThread(self._register_user, existing_user) d.addCallback(self._check_other_nodes) d.addErrback(self.trap_cancel) - d.addErrback(self.err_overload, "hello") - d.addErrback(self.err_hello) + d.addErrback(self.error_overload, "hello") + d.addErrback(self.error_hello) self.ps._register = d return d @@ -779,8 +780,7 @@ def _verify_user_record(self): self.log.debug(format="Dropping User", code=105, uaid_hash=self.ps.uaid_hash, uaid_record=repr(record)) - self.force_retry(self.db.router.drop_user, - self.ps.uaid) + self.force_retry(self.db.router.drop_user, self.ps.uaid) tags = ['code:105'] self.metrics.increment("ua.expiration", tags=tags) return None @@ -806,7 +806,7 @@ def _verify_user_record(self): record["connected_at"] = self.ps.connected_at return record - def err_hello(self, failure): + def error_hello(self, failure): """errBack for hello failures""" self.transport.resumeProducing() self.log_failure(failure) @@ -903,11 +903,14 @@ def process_notifications(self): def webpush_fetch(self): """Helper to return an appropriate function to fetch messages""" + message = self.db.message_table(self.ps.message_month) if self.ps.scan_timestamps: - return partial(self.ps.message.fetch_timestamp_messages, - self.ps.uaid_obj, self.ps.current_timestamp) + return partial(message.fetch_timestamp_messages, + self.ps.uaid_obj, + self.ps.current_timestamp) else: - return partial(self.ps.message.fetch_messages, self.ps.uaid_obj) + return partial(message.fetch_messages, + self.ps.uaid_obj) def error_notifications(self, fail): """errBack for notification check failing""" @@ -931,7 +934,11 @@ def error_notification_overload(self, fail): def error_message_overload(self, fail): """errBack for handling excessive messages per UAID""" fail.trap(MessageOverloadException) - self.force_retry(self.db.router.drop_user, self.ps.uaid) + ddb_resource = self.db.resources.fetch() + try: + self.force_retry(self.db.router.drop_user, self.ps.uaid) + finally: + self.db.resources.release(ddb_resource) self.sendClose() def finish_notifications(self, notifs): @@ -978,7 +985,11 @@ def finish_webpush_notifications(self, result): # Told to reset the user? if self.ps.reset_uaid: - self.force_retry(self.db.router.drop_user, self.ps.uaid) + ddb_resource = self.db.resources.fetch() + try: + self.force_retry(self.db.router.drop_user, self.ps.uaid) + finally: + self.db.resources.release(ddb_resource) self.sendClose() # Not told to check for notifications, do we need to now rotate @@ -990,13 +1001,15 @@ def finish_webpush_notifications(self, result): # Send out all the notifications now = int(time.time()) messages_sent = False + message = self.db.message_table(self.ps.message_month) for notif in notifs: self.ps.stats.stored_retrieved += 1 # If the TTL is too old, don't deliver and fire a delete off if notif.expired(at_time=now): if not notif.sortkey_timestamp: # Delete non-timestamped messages - self.force_retry(self.ps.message.delete_message, notif) + self.force_retry(message.delete_message, + notif) # nocover here as coverage gets confused on the line below # for unknown reasons @@ -1010,7 +1023,8 @@ def finish_webpush_notifications(self, result): raise MessageOverloadException() if notif.topic: self.metrics.increment("ua.notification.topic") - self.metrics.increment('ua.message_data', len(msg.get('data', '')), + self.metrics.increment('ua.message_data', + len(msg.get('data', '')), tags=make_tags(source=notif.source)) self.sendJSON(msg) @@ -1021,8 +1035,9 @@ def finish_webpush_notifications(self, result): # No messages sent, update the record if needed if self.ps.current_timestamp: self.force_retry( - self.ps.message.update_last_message_read, - self.ps.uaid_obj, self.ps.current_timestamp) + message.update_last_message_read, + self.ps.uaid_obj, + self.ps.current_timestamp) # Schedule a new process check self.check_missed_notifications(None) @@ -1047,13 +1062,14 @@ def _monthly_transition(self): """ # Get the current channels for this month - _, channels = self.ps.message.all_channels(self.ps.uaid) + message = self.db.message_table(self.ps.message_month) + _, channels = message.all_channels(self.ps.uaid) # Get the current message month cur_month = self.db.current_msg_month if channels: # Save the current channels into this months message table - msg_table = self.db.message_tables[cur_month] + msg_table = self.db.message_table(cur_month) msg_table.save_channels(self.ps.uaid, channels) # Finally, update the route message month @@ -1069,7 +1085,7 @@ def _finish_monthly_transition(self, result): def error_monthly_rotation_overload(self, fail): """Capture overload on monthly table rotation attempt - If a provision exdeeded error hits while attempting monthly table + If a provision exceeded error hits while attempting monthly table rotation, schedule it all over and re-scan the messages. Normal websocket client flow is returned in the meantime. @@ -1137,12 +1153,13 @@ def error_register(self, fail): def finish_register(self, endpoint, chid): """callback for successful endpoint creation, sends register reply""" - d = self.deferToThread(self.ps.message.register_channel, self.ps.uaid, + message = self.db.message_table(self.ps.message_month) + d = self.deferToThread(message.register_channel, self.ps.uaid, chid) d.addCallback(self.send_register_finish, endpoint, chid) # Note: No trap_cancel needed here since the deferred here is # returned to process_register which will trap it - d.addErrback(self.err_overload, "register", disconnect=False) + d.addErrback(self.error_overload, "register", disconnect=False) return d def send_register_finish(self, result, endpoint, chid): @@ -1180,7 +1197,8 @@ def process_unregister(self, data): self.ps.updates_sent[chid] = [] # Unregister the channel - self.force_retry(self.ps.message.unregister_channel, self.ps.uaid, + message = self.db.message_table(self.ps.message_month) + self.force_retry(message.unregister_channel, self.ps.uaid, chid) data["status"] = 200 @@ -1238,6 +1256,7 @@ def ver_filter(notif): **self.ps.raw_agent) self.ps.stats.stored_acked += 1 + message = self.db.message_table(self.ps.message_month) if msg.sortkey_timestamp: # Is this the last un-acked message we're waiting for? last_unacked = sum( @@ -1249,7 +1268,7 @@ def ver_filter(notif): # If it's the last message in the batch, or last un-acked # message d = self.force_retry( - self.ps.message.update_last_message_read, + message.update_last_message_read, self.ps.uaid_obj, self.ps.current_timestamp, ) d.addBoth(self._handle_webpush_update_remove, chid, msg) @@ -1260,7 +1279,7 @@ def ver_filter(notif): d = None else: # No sortkey_timestamp, so legacy/topic message, delete - d = self.force_retry(self.ps.message.delete_message, msg) + d = self.force_retry(message.delete_message, msg) # We don't remove the update until we know the delete ran # This is because we don't use range queries on dynamodb and # we need to make sure this notification is deleted from the diff --git a/autopush_rs/Cargo.lock b/autopush_rs/Cargo.lock index 400b6032..6af06de1 100644 --- a/autopush_rs/Cargo.lock +++ b/autopush_rs/Cargo.lock @@ -1,4 +1,4 @@ -[root] +[[package]] name = "autopush" version = "0.1.0" dependencies = [