diff --git a/.travis.yml b/.travis.yml index 6a2e6457..163463e4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,7 +12,7 @@ matrix: - python: 2.7 env: TOXENV=py27 DDB=true CODECOV=true - python: pypy - env: TOXENV=pypy DDB=true CODECOV=true + env: TOXENV=pypy DDB=true CODECOV=true AWS_LOCAL_DYNAMODB=http://127.0.0.1:8000 - env: TOXENV=flake8 WITH_RUST=false - python: 3.6 env: TOXENV=py36-mypy WITH_RUST=false diff --git a/autopush/db.py b/autopush/db.py index 7fee98cd..24a52d68 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -36,20 +36,22 @@ import random import time import uuid +from collections import deque from functools import wraps from attr import ( asdict, attrs, attrib, - Factory -) + Factory, + Attribute) from boto.dynamodb2.exceptions import ( ItemNotFound, ) import boto3 import botocore +from boto3 import Session # noqa from boto3.dynamodb.conditions import Key from boto3.exceptions import Boto3Error from botocore.exceptions import ClientError @@ -96,15 +98,42 @@ TRACK_DB_CALLS = False DB_CALLS = [] -# See https://botocore.readthedocs.io/en/stable/reference/config.html for -# additional config options -g_dynamodb = boto3.resource( - 'dynamodb', - config=botocore.config.Config( - region_name=os.getenv("AWS_REGION_NAME", "us-east-1") - ) -) -g_client = g_dynamodb.meta.client +SESSION_ARGS = {} +MAX_SESSIONS = 10 + + +class BotoSessions(object): + + def __init__(self, conf=None): + if conf is None: + conf = SESSION_ARGS + self.pool = deque(maxlen=MAX_SESSIONS) + session = boto3.session.Session() + if not conf.get("endpoint_url") and os.getenv("AWS_LOCAL_DYNAMODB"): + conf["endpoint_url"] = os.getenv("AWS_LOCAL_DYNAMODB") + self.pool.extendleft((session.resource( + 'dynamodb', + config=botocore.config.Config( + region_name=os.getenv("AWS_REGION_NAME", "us-east-1") + ), + **conf + ) for x in range(0, MAX_SESSIONS))) + + def fetch(self): + try: + return self.pool.pop() + except IndexError: + raise ClientError( + { + 'Error': { + 'Code': 
'ProvisionedThroughputExceededException', + 'Message': 'Session pool exhausted', + } + }, + 'session.fetch') + + def release(self, resource): + self.pool.appendleft(resource) def get_month(delta=0): @@ -146,20 +175,21 @@ def make_rotating_tablename(prefix, delta=0, date=None): def create_rotating_message_table(prefix="message", delta=0, date=None, read_throughput=5, - write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table + write_throughput=5, + boto_session=None): + # type: (str, int, Optional[datetime.date], int, int, Session) -> Table """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta, date) try: - table = g_dynamodb.Table(tablename) + table = boto_session.Table(tablename) if table.table_status == 'ACTIVE': # pragma nocover return table except ClientError as ex: if ex.response['Error']['Code'] != 'ResourceNotFoundException': # If we hit this, our boto3 is misconfigured and we need to bail. 
raise ex # pragma nocover - table = g_dynamodb.create_table( + table = boto_session.create_table( TableName=tablename, KeySchema=[ { @@ -200,24 +230,30 @@ def create_rotating_message_table(prefix="message", delta=0, date=None, return table -def get_rotating_message_table(prefix="message", delta=0, date=None, - message_read_throughput=5, - message_write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table +def get_rotating_message_tablename(prefix="message", delta=0, date=None, + message_read_throughput=5, + message_write_throughput=5, + boto_session=None): + # type: (str, int, Optional[datetime.date], int, int, Session) -> str """Gets the message table for the current month.""" + if not boto_session: + raise Exception("Missing session") tablename = make_rotating_tablename(prefix, delta, date) - if not table_exists(tablename): - return create_rotating_message_table( + if not table_exists(tablename, boto_session=boto_session): + create_rotating_message_table( prefix=prefix, delta=delta, date=date, read_throughput=message_read_throughput, write_throughput=message_write_throughput, + boto_session=boto_session ) + return tablename else: - return g_dynamodb.Table(tablename) + return tablename def create_router_table(tablename="router", read_throughput=5, - write_throughput=5): + write_throughput=5, + boto_session=None): # type: (str, int, int) -> Table """Create a new router table @@ -232,8 +268,7 @@ def create_router_table(tablename="router", read_throughput=5, cost of additional queries during GC to locate expired users. 
""" - - table = g_dynamodb.create_table( + table = boto_session.create_table( TableName=tablename, KeySchema=[ { @@ -292,20 +327,22 @@ def create_router_table(tablename="router", read_throughput=5, return table -def _drop_table(tablename): +def _drop_table(tablename, boto_session): try: - g_client.delete_table(TableName=tablename) + boto_session.meta.client.delete_table(TableName=tablename) except ClientError: # pragma nocover pass -def _make_table(table_func, tablename, read_throughput, write_throughput): +def _make_table(table_func, tablename, read_throughput, write_throughput, + boto_session): # type: (Callable[[str, int, int], Table], str, int, int) -> Table """Private common function to make a table with a table func""" - if not table_exists(tablename): - return table_func(tablename, read_throughput, write_throughput) + if not table_exists(tablename, boto_session): + return table_func(tablename, read_throughput, write_throughput, + boto_session) else: - return g_dynamodb.Table(tablename) + return boto_session.Table(tablename) def _expiry(ttl): @@ -313,7 +350,7 @@ def _expiry(ttl): def get_router_table(tablename="router", read_throughput=5, - write_throughput=5): + write_throughput=5, boto_session=None): # type: (str, int, int) -> Table """Get the main router table object @@ -322,10 +359,11 @@ def get_router_table(tablename="router", read_throughput=5, """ return _make_table(create_router_table, tablename, read_throughput, - write_throughput) + write_throughput, boto_session=boto_session) -def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): +def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000", + boto_session=None): # type: (Message, Router, str) -> None """Performs a pre-flight check of the router/message to ensure appropriate permissions for operation. 
@@ -336,7 +374,8 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): # Verify tables are ready for use if they just got created ready = False while not ready: - tbl_status = [x.table_status() for x in [message, router]] + tbl_status = [x.table_status(boto_session=boto_session) + for x in [message, router]] ready = all([status == "ACTIVE" for status in tbl_status]) if not ready: time.sleep(1) @@ -354,20 +393,20 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): message_id=message_id, ttl=60, ) - # Store a notification, fetch it, delete it - message.store_message(notif) - assert message.delete_message(notif) + message.store_message(notif, boto_session=boto_session) + assert message.delete_message(notif, boto_session=boto_session) # Store a router entry, fetch it, delete it router.register_user(dict(uaid=uaid.hex, node_id=node_id, connected_at=connected_at, - router_type="webpush")) - item = router.get_uaid(uaid.hex) + router_type="webpush"), + boto_session=boto_session) + item = router.get_uaid(uaid.hex, boto_session=boto_session) assert item.get("node_id") == node_id # Clean up the preflight data. 
- router.clear_node(item) - router.drop_user(uaid.hex) + router.clear_node(item, boto_session=boto_session) + router.drop_user(uaid.hex, boto_session=boto_session) def track_provisioned(func): @@ -430,49 +469,42 @@ def generate_last_connect_values(date): yield int(val) -def list_tables(client=g_client): - """Return a list of the names of all DynamoDB tables.""" - start_table = None - while True: - if start_table: # pragma nocover - result = client.list_tables(ExclusiveStartTableName=start_table) - else: - result = client.list_tables() - for table in result.get('TableNames', []): - yield table - start_table = result.get('LastEvaluatedTableName', None) - if not start_table: - break - - -def table_exists(tablename, client=None): +def table_exists(tablename, boto_session): """Determine if the specified Table exists""" - if not client: - client = g_client - return tablename in list_tables(client) + try: + return boto_session.Table(tablename).table_status in [ + 'CREATING', 'UPDATING', 'ACTIVE'] + except ClientError: + return False class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics) -> None + def __init__(self, tablename, metrics=None, boto_sessions=None, + max_ttl=MAX_EXPIRY): + # type: (str, IMetrics, int) -> None """Create a new Message object - :param table: :class:`Table` object. - :param metrics: Metrics object that implements the - :class:`autopush.metrics.IMetrics` interface. + :param tablename: name of the table. 
+ :param metrics: unused + :param session: Session for thread """ - self.table = table - self.metrics = metrics + self.tablename = tablename self._max_ttl = max_ttl + self._sessions = boto_sessions - def table_status(self): - return self.table.table_status + def table(self, boto_session, tablename=None): + if not tablename: + tablename = self.tablename + return boto_session.Table(tablename) + + def table_status(self, boto_session): + return self.table(boto_session).table_status @track_provisioned - def register_channel(self, uaid, channel_id, ttl=None): - # type: (str, str, int) -> bool + def register_channel(self, uaid, channel_id, boto_session=None, ttl=None): + # type: (str, str, boto_session, int) -> bool """Register a channel for a given uaid""" # Generate our update expression if ttl is None: @@ -481,80 +513,114 @@ def register_channel(self, uaid, channel_id, ttl=None): ":channel_id": set([normalize_id(channel_id)]), ":expiry": _expiry(ttl) } - self.table.update_item( - Key={ - 'uaid': hasher(uaid), - 'chidmessageid': ' ', - }, - UpdateExpression='ADD chids :channel_id, expiry :expiry', - ExpressionAttributeValues=expr_values, - ) - return True + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + self.table(boto_session).update_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + UpdateExpression='ADD chids :channel_id, expiry :expiry', + ExpressionAttributeValues=expr_values, + ) + return True + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def unregister_channel(self, uaid, channel_id, **kwargs): + def unregister_channel(self, uaid, channel_id, boto_session=None, + **kwargs): # type: (str, str, **str) -> bool """Remove a channel registration for a given uaid""" - expr = "DELETE chids :channel_id" - chid = normalize_id(channel_id) - expr_values = {":channel_id": set([chid])} - - response = self.table.update_item( - Key={ - 'uaid': hasher(uaid), - 
'chidmessageid': ' ', - }, - UpdateExpression=expr, - ExpressionAttributeValues=expr_values, - ReturnValues="UPDATED_OLD", - ) - chids = response.get('Attributes', {}).get('chids', {}) - if chids: - try: - return chid in chids - except (TypeError, AttributeError): # pragma: nocover - pass - # if, for some reason, there are no chids defined, return False. - return False + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + expr = "DELETE chids :channel_id" + chid = normalize_id(channel_id) + expr_values = {":channel_id": set([chid])} + + response = self.table(boto_session).update_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, + ReturnValues="UPDATED_OLD", + ) + chids = response.get('Attributes', {}).get('chids', {}) + if chids: + try: + return chid in chids + except (TypeError, AttributeError): # pragma: nocover + pass + # if, for some reason, there are no chids defined, return False. + return False + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def all_channels(self, uaid): + def all_channels(self, uaid, boto_session=None): # type: (str) -> Tuple[bool, Set[str]] """Retrieve a list of all channels for a given uaid""" # Note: This only returns the chids associated with the UAID. 
# Functions that call store_message() would be required to # update that list as well using register_channel() - result = self.table.get_item( - Key={ - 'uaid': hasher(uaid), - 'chidmessageid': ' ', - }, - ConsistentRead=True - ) - if result['ResponseMetadata']['HTTPStatusCode'] != 200: - return False, set([]) - if 'Item' not in result: - return False, set([]) - return True, result['Item'].get("chids", set([])) + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + result = self.table(boto_session).get_item( + Key={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + }, + ConsistentRead=True + ) + if result['ResponseMetadata']['HTTPStatusCode'] != 200: + return False, set([]) + if 'Item' not in result: + return False, set([]) + return True, result['Item'].get("chids", set([])) + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def save_channels(self, uaid, channels): + def save_channels(self, uaid, channels, boto_session=None): # type: (str, Set[str]) -> None """Save out a set of channels""" - self.table.put_item( - Item={ - 'uaid': hasher(uaid), - 'chidmessageid': ' ', - 'chids': channels, - 'expiry': _expiry(self._max_ttl), - }, - ) + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + self.table(boto_session).put_item( + Item={ + 'uaid': hasher(uaid), + 'chidmessageid': ' ', + 'chids': channels, + 'expiry': _expiry(self._max_ttl), + }, + ) + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def store_message(self, notification): + def store_message(self, notification, boto_session=None): # type: (WebPushNotification) -> None """Stores a WebPushNotification in the message table""" + release = False item = dict( uaid=hasher(notification.uaid.hex), chidmessageid=notification.sort_key, @@ -568,39 +634,55 @@ def store_message(self, notification): ) if notification.data: item['data'] = 
notification.data - self.table.put_item(Item=item) + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + self.table(boto_session).put_item(Item=item) + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def delete_message(self, notification): + def delete_message(self, notification, boto_session=None): # type: (WebPushNotification) -> bool """Deletes a specific message""" - if notification.update_id: - try: - self.table.delete_item( + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + if notification.update_id: + try: + self.table(boto_session).delete_item( + Key={ + 'uaid': hasher(notification.uaid.hex), + 'chidmessageid': notification.sort_key + }, + Expected={ + 'updateid': { + 'Exists': True, + 'Value': notification.update_id + } + }) + except ClientError: + return False + else: + self.table(boto_session).delete_item( Key={ 'uaid': hasher(notification.uaid.hex), - 'chidmessageid': notification.sort_key - }, - Expected={ - 'updateid': { - 'Exists': True, - 'Value': notification.update_id - } + 'chidmessageid': notification.sort_key, }) - except ClientError: - return False - else: - self.table.delete_item( - Key={ - 'uaid': hasher(notification.uaid.hex), - 'chidmessageid': notification.sort_key, - }) - return True + return True + finally: + if release: + self._sessions.release(boto_session) @track_provisioned def fetch_messages( self, uaid, # type: uuid.UUID + boto_session=None, limit=10, # type: int ): # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] @@ -610,27 +692,35 @@ def fetch_messages( messages and the list of non-timestamped messages. """ - # Eagerly fetches all results in the result set. 
- response = self.table.query( - KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex)) - & Key('chidmessageid').lt('02')), - ConsistentRead=True, - Limit=limit, - ) - results = list(response['Items']) - # First extract the position if applicable, slightly higher than 01: - # to ensure we don't load any 01 remainders that didn't get deleted - # yet - last_position = None - if results: - # Ensure we return an int, as boto2 can return Decimals - if results[0].get("current_timestamp"): - last_position = int(results[0]["current_timestamp"]) - - return last_position, [ - WebPushNotification.from_message_table(uaid, x) - for x in results[1:] - ] + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + # Eagerly fetches all results in the result set. + response = self.table(boto_session).query( + KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex)) + & Key('chidmessageid').lt('02')), + ConsistentRead=True, + Limit=limit, + ) + results = list(response['Items']) + # First extract the position if applicable, slightly higher than + # 01: to ensure we don't load any 01 remainders that didn't get + # deleted yet + last_position = None + if results: + # Ensure we return an int, as boto2 can return Decimals + if results[0].get("current_timestamp"): + last_position = int(results[0]["current_timestamp"]) + + return last_position, [ + WebPushNotification.from_message_table(uaid, x) + for x in results[1:] + ] + finally: + if release: + self._sessions.release(boto_session) @track_provisioned def fetch_timestamp_messages( @@ -638,6 +728,7 @@ def fetch_timestamp_messages( uaid, # type: uuid.UUID timestamp=None, # type: Optional[Union[int, str]] limit=10, # type: int + boto_session=None, # type: Session ): # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] """Fetches timestamped messages for a uaid @@ -650,50 +741,66 @@ def fetch_timestamp_messages( timestamped messages. 
""" - # Turn the timestamp into a proper sort key - if timestamp: - sortkey = "02:{timestamp}:z".format(timestamp=timestamp) - else: - sortkey = "01;" + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True + try: + # Turn the timestamp into a proper sort key + if timestamp: + sortkey = "02:{timestamp}:z".format(timestamp=timestamp) + else: + sortkey = "01;" - response = self.table.query( - KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex)) - & Key('chidmessageid').gt(sortkey)), - ConsistentRead=True, - Limit=limit - ) - notifs = [ - WebPushNotification.from_message_table(uaid, x) for x in - response.get("Items") - ] - ts_notifs = [x for x in notifs if x.sortkey_timestamp] - last_position = None - if ts_notifs: - last_position = ts_notifs[-1].sortkey_timestamp - return last_position, notifs + response = self.table(boto_session).query( + KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex)) + & Key('chidmessageid').gt(sortkey)), + ConsistentRead=True, + Limit=limit + ) + notifs = [ + WebPushNotification.from_message_table(uaid, x) for x in + response.get("Items") + ] + ts_notifs = [x for x in notifs if x.sortkey_timestamp] + last_position = None + if ts_notifs: + last_position = ts_notifs[-1].sortkey_timestamp + return last_position, notifs + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def update_last_message_read(self, uaid, timestamp): - # type: (uuid.UUID, int) -> bool + def update_last_message_read(self, uaid, timestamp, boto_session=None): + # type: (uuid.UUID, int, Session) -> bool """Update the last read timestamp for a user""" - expr = "SET current_timestamp=:timestamp, expiry=:expiry" - expr_values = {":timestamp": timestamp, - ":expiry": _expiry(self._max_ttl)} - self.table.update_item( - Key={ - "uaid": hasher(uaid.hex), - "chidmessageid": " " - }, - UpdateExpression=expr, - ExpressionAttributeValues=expr_values, - ) - return True + release = False + if not 
boto_session: + boto_session = self._sessions.fetch() + release = True + try: + expr = "SET current_timestamp=:timestamp, expiry=:expiry" + expr_values = {":timestamp": timestamp, + ":expiry": _expiry(self._max_ttl)} + self.table(boto_session).update_item( + Key={ + "uaid": hasher(uaid.hex), + "chidmessageid": " " + }, + UpdateExpression=expr, + ExpressionAttributeValues=expr_values, + ) + return True + finally: + if release: + self._sessions.release(boto_session) class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics) -> None + def __init__(self, conf, metrics, boto_sessions, max_ttl=MAX_EXPIRY): + # type: (dict, IMetrics, BotoSessions) -> None """Create a new Router object :param table: :class:`Table` object. @@ -701,14 +808,30 @@ def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): :class:`autopush.metrics.IMetrics` interface. """ - self.table = table + self.conf = conf self.metrics = metrics self._max_ttl = max_ttl + self._session = None + self._cached_table = None + self._sessions = boto_sessions or BotoSessions() + self._session = None + + def table(self, boto_session): + if self._cached_table and boto_session == self._session: + return self._cached_table + self._session = boto_session + + if self.conf: + self._cached_table = get_router_table(boto_session=boto_session, + **asdict(self.conf)) + else: + self._cached_table = get_router_table(boto_session=boto_session) + return self._cached_table - def table_status(self): - return self.table.table_status + def table_status(self, boto_session): + return self.table(boto_session).table_status - def get_uaid(self, uaid): + def get_uaid(self, uaid, boto_session=None): # type: (str) -> Item """Get the database record for the UAID @@ -718,8 +841,12 @@ def get_uaid(self, uaid): exceeds throughput. 
""" + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True try: - item = self.table.get_item( + item = self.table(boto_session).get_item( Key={ 'uaid': hasher(uaid) }, @@ -741,9 +868,12 @@ def get_uaid(self, uaid): # JSON when looking up values in empty tables. We re-throw the # correct ItemNotFound exception raise ItemNotFound("uaid not found") + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def register_user(self, data): + def register_user(self, data, boto_session=None): # type: (ItemLike) -> Tuple[bool, Dict[str, Any]] """Register this user @@ -756,79 +886,97 @@ def register_user(self, data): exceeds throughput. """ - # Fetch a senderid for this user - db_key = {"uaid": hasher(data["uaid"])} - del data["uaid"] - if "router_type" not in data or "connected_at" not in data: - # Not specifying these values will generate an exception in AWS. - raise AutopushException("data is missing router_type " - "or connected_at") - # Generate our update expression - expr = "SET " + ", ".join(["%s=:%s" % (x, x) for x in data.keys()]) - expr_values = {":%s" % k: v for k, v in data.items()} + release = False + if not boto_session: + boto_session = self._sessions.fetch() + release = True try: - cond = """( - attribute_not_exists(router_type) or - (router_type = :router_type) - ) and ( - attribute_not_exists(node_id) or - (connected_at < :connected_at) - )""" - result = self.table.update_item( - Key=db_key, - UpdateExpression=expr, - ConditionExpression=cond, - ExpressionAttributeValues=expr_values, - ReturnValues="ALL_OLD", - ) - if "Attributes" in result: - r = {} - for key, value in result["Attributes"].items(): - try: - r[key] = self.table._dynamizer.decode(value) - except (TypeError, AttributeError): # pragma: nocover - # Included for safety as moto has occasionally made - # this not work - r[key] = value - result = r - return (True, result) - except ClientError as ex: - # ClientErrors are generated 
by a factory, and while they have a - # class, it's dynamically generated. - if ex.response['Error']['Code'] == \ - 'ConditionalCheckFailedException': - return (False, {}) - raise + # Fetch a senderid for this user + db_key = {"uaid": hasher(data["uaid"])} + del data["uaid"] + if "router_type" not in data or "connected_at" not in data: + # Not specifying these values will generate an exception in + # AWS. + raise AutopushException("data is missing router_type " + "or connected_at") + # Generate our update expression + expr = "SET " + ", ".join(["%s=:%s" % (x, x) for x in data.keys()]) + expr_values = {":%s" % k: v for k, v in data.items()} + try: + cond = """( + attribute_not_exists(router_type) or + (router_type = :router_type) + ) and ( + attribute_not_exists(node_id) or + (connected_at < :connected_at) + )""" + result = self.table(boto_session).update_item( + Key=db_key, + UpdateExpression=expr, + ConditionExpression=cond, + ExpressionAttributeValues=expr_values, + ReturnValues="ALL_OLD", + ) + if "Attributes" in result: + r = {} + for key, value in result["Attributes"].items(): + try: + r[key] = self.table( + boto_session)._dynamizer.decode(value) + except (TypeError, AttributeError): # pragma: nocover + # Included for safety as moto has occasionally made + # this not work + r[key] = value + result = r + return (True, result) + except ClientError as ex: + # ClientErrors are generated by a factory, and while they have + # a class, it's dynamically generated. + if ex.response['Error']['Code'] == \ + 'ConditionalCheckFailedException': + return (False, {}) + raise + finally: + if release: + self._sessions.release(boto_session) @track_provisioned - def drop_user(self, uaid): + def drop_user(self, uaid, boto_session=None): # type: (str) -> bool """Drops a user record""" # The following hack ensures that only uaids that exist and are # deleted return true. 
+ release = False + if not boto_session: + boto_session = self._sessions.fetch() try: - item = self.table.get_item( - Key={ - 'uaid': hasher(uaid) - }, - ConsistentRead=True, - ) - if 'Item' not in item: - return False - except ClientError: - pass - result = self.table.delete_item(Key={'uaid': hasher(uaid)}) - return result['ResponseMetadata']['HTTPStatusCode'] == 200 - - def delete_uaids(self, uaids): + try: + item = self.table(boto_session).get_item( + Key={ + 'uaid': hasher(uaid) + }, + ConsistentRead=True, + ) + if 'Item' not in item: + return False + except ClientError: + pass + result = self.table(boto_session).delete_item( + Key={'uaid': hasher(uaid)}) + return result['ResponseMetadata']['HTTPStatusCode'] == 200 + finally: + if release: + self._sessions.release(boto_session) + + def delete_uaids(self, uaids, boto_session): # type: (List[str]) -> None """Issue a batch delete call for the given uaids""" - with self.table.batch_writer() as batch: + with self.table(boto_session).batch_writer() as batch: for uaid in uaids: batch.delete_item(Key={'uaid': uaid}) - def drop_old_users(self, months_ago=2): - # type: (int) -> Iterable[int] + def drop_old_users(self, boto_session, months_ago=2): + # type: (int, Session) -> Iterable[int] """Drops user records that have no recent connection Utilizes the last_connect index to locate users that haven't @@ -849,6 +997,7 @@ def drop_old_users(self, months_ago=2): quickly as possible. 
:param months_ago: how many months ago since the last connect + :param boto_session: Session for thread :returns: Iterable of how many deletes were run @@ -857,7 +1006,7 @@ def drop_old_users(self, months_ago=2): batched = [] for hash_key in generate_last_connect_values(prior_date): - response = self.table.query( + response = self.table(boto_session).query( KeyConditionExpression=Key("last_connect").eq(hash_key), IndexName="AccessIndex", ) @@ -866,26 +1015,26 @@ def drop_old_users(self, months_ago=2): batched.append(result["uaid"]) if len(batched) == 25: - self.delete_uaids(batched) + self.delete_uaids(batched, boto_session=boto_session) batched = [] yield 25 # Delete any leftovers if batched: - self.delete_uaids(batched) + self.delete_uaids(batched, boto_session=boto_session) yield len(batched) @track_provisioned - def _update_last_connect(self, uaid, last_connect): - self.table.update_item( + def _update_last_connect(self, uaid, last_connect, boto_session): + self.table(boto_session).update_item( Key={"uaid": hasher(uaid)}, UpdateExpression="SET last_connect=:last_connect", ExpressionAttributeValues={":last_connect": last_connect} ) @track_provisioned - def update_message_month(self, uaid, month): - # type: (str, str) -> bool + def update_message_month(self, uaid, month, boto_session): + # type: (str, str, Session) -> bool """Update the route tables current_message_month Note that we also update the last_connect at this point since webpush @@ -901,7 +1050,7 @@ def update_message_month(self, uaid, month): ":last_connect": generate_last_connect(), ":expiry": _expiry(self._max_ttl), } - self.table.update_item( + self.table(boto_session).update_item( Key=db_key, UpdateExpression=expr, ExpressionAttributeValues=expr_values, @@ -909,8 +1058,8 @@ def update_message_month(self, uaid, month): return True @track_provisioned - def clear_node(self, item): - # type: (dict) -> bool + def clear_node(self, item, boto_session): + # type: (dict, Session) -> bool """Given a router 
item and remove the node_id The node_id will only be cleared if the ``connected_at`` matches up @@ -928,7 +1077,7 @@ def clear_node(self, item): try: cond = "(node_id = :node) and (connected_at = :conn)" - self.table.put_item( + self.table(boto_session).put_item( Item=item, ConditionExpression=cond, ExpressionAttributeValues={ @@ -951,15 +1100,15 @@ class DatabaseManager(object): _router_conf = attrib() # type: DDBTableConfig _message_conf = attrib() # type: DDBTableConfig - - metrics = attrib() # type: IMetrics + metrics = attrib() # type: IMetrics + _sessions = attrib() # type: BotoSessions router = attrib(default=None) # type: Optional[Router] - message_tables = attrib(default=Factory(dict)) # type: Dict[str, Message] + message_tables = attrib(default=Factory(list)) # type: List[str] current_msg_month = attrib(init=False) # type: Optional[str] current_month = attrib(init=False) # type: Optional[int] + _message = attrib(init=None) # type: Optional[Message] # for testing: - client = attrib(default=g_client) # type: Optional[Any] def __attrs_post_init__(self): """Initialize sane defaults""" @@ -969,9 +1118,11 @@ def __attrs_post_init__(self): self._message_conf.tablename, date=today ) + if not self._sessions: + self._sessions = BotoSessions() @classmethod - def from_config(cls, conf, **kwargs): + def from_config(cls, conf, boto_sessions=None, **kwargs): # type: (AutopushConfig, **Any) -> DatabaseManager """Create a DatabaseManager from the given config""" metrics = autopush.metrics.from_config(conf) @@ -979,6 +1130,7 @@ def from_config(cls, conf, **kwargs): router_conf=conf.router_table, message_conf=conf.message_table, metrics=metrics, + sessions=boto_sessions, **kwargs ) @@ -986,14 +1138,21 @@ def setup(self, preflight_uaid): # type: (str) -> None """Setup metrics, message tables and perform preflight_check""" self.metrics.start() - self.setup_tables() - preflight_check(self.message, self.router, preflight_uaid) + session = self._sessions.fetch() + try: + 
self.setup_tables(boto_session=session) + preflight_check(self.message, self.router, preflight_uaid, + boto_session=session) + finally: + self._sessions.release(session) - def setup_tables(self): + def setup_tables(self, boto_session=None): """Lookup or create the database tables""" + release = False self.router = Router( - get_router_table(**asdict(self._router_conf)), - self.metrics + conf=self._router_conf, + metrics=self.metrics, + boto_sessions=self._sessions, ) # Used to determine whether a connection is out of date with current # db objects. There are three noteworty cases: @@ -1003,25 +1162,42 @@ def setup_tables(self): # timing, some nodes may roll over sooner. Ensuring the next month's # table is present before the switchover is the main reason for this, # just in case some nodes do switch sooner. - self.create_initial_message_tables() + if not boto_session: + boto_session = self.sessions.fetch() + release = True + try: + self.create_initial_message_tables(boto_session) + finally: + if release: + self.sessions.release(boto_session) + self._message = Message(self.current_msg_month, + self.metrics, + boto_sessions=self.sessions) @property def message(self): # type: () -> Message """Property that access the current message table""" - return self.message_tables[self.current_msg_month] + if not self._message or isinstance(self._message, Attribute): + self._message = self.message_table(self.current_msg_month) + return self._message + + @property + def sessions(self): + return self._sessions @message.setter - def message(self, value): - # type: (Message) -> None - """Setter to set the current message table""" - self.message_tables[self.current_msg_month] = value + def message_set(self, value): + self._message = value + + def message_table(self, tablename): + return Message(tablename, self.metrics, boto_sessions=self.sessions) def _tomorrow(self): # type: () -> datetime.date return datetime.date.today() + datetime.timedelta(days=1) - def 
create_initial_message_tables(self): + def create_initial_message_tables(self, boto_session): """Initializes a dict of the initial rotating messages tables. An entry for last months table, an entry for this months table, @@ -1030,32 +1206,31 @@ def create_initial_message_tables(self): """ mconf = self._message_conf today = datetime.date.today() - last_month = get_rotating_message_table( + last_month = get_rotating_message_tablename( prefix=mconf.tablename, delta=-1, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_session=boto_session, ) - this_month = get_rotating_message_table( + this_month = get_rotating_message_tablename( prefix=mconf.tablename, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_session=boto_session, ) self.current_month = today.month - self.current_msg_month = this_month.table_name - self.message_tables = { - last_month.table_name: Message(last_month, self.metrics), - this_month.table_name: Message(this_month, self.metrics) - } + self.current_msg_month = this_month + self.message_tables = [last_month, this_month] if self._tomorrow().month != today.month: - next_month = get_rotating_message_table( + next_month = get_rotating_message_tablename( prefix=mconf.tablename, delta=1, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + boto_session=boto_session, ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) + self.message_tables.append(next_month) @inlineCallbacks def update_rotating_tables(self): @@ -1070,35 +1245,36 @@ def update_rotating_tables(self): mconf = self._message_conf today = datetime.date.today() tomorrow = self._tomorrow() - if ((tomorrow.month != today.month) and - 
sorted(self.message_tables.keys())[-1] != tomorrow.month): - next_month = yield deferToThread( - get_rotating_message_table, + boto_session = self._sessions.fetch() + try: + if ((tomorrow.month != today.month) and + sorted(self.message_tables)[-1] != tomorrow.month): + next_month = yield deferToThread( + get_rotating_message_tablename, + prefix=mconf.tablename, + delta=0, + date=tomorrow, + message_read_throughput=mconf.read_throughput, + message_write_throughput=mconf.write_throughput, + boto_session=boto_session + ) + self.message_tables.append(next_month) + if today.month == self.current_month: + # No change in month, we're fine. + returnValue(False) + + # Get tables for the new month, and verify they exist before we + # try to switch over + message_table = yield deferToThread( + get_rotating_message_tablename, prefix=mconf.tablename, - delta=0, - date=tomorrow, message_read_throughput=mconf.read_throughput, message_write_throughput=mconf.write_throughput ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) - - if today.month == self.current_month: - # No change in month, we're fine. 
- returnValue(False) - # Get tables for the new month, and verify they exist before we try to - # switch over - message_table = yield deferToThread( - get_rotating_message_table, - prefix=mconf.tablename, - message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput - ) - - # Both tables found, safe to switch-over - self.current_month = today.month - self.current_msg_month = message_table.table_name - self.message_tables[self.current_msg_month] = Message( - message_table, self.metrics) + # Both tables found, safe to switch-over + self.current_month = today.month + self.current_msg_month = message_table + finally: + self._sessions.release(boto_session) returnValue(True) diff --git a/autopush/diagnostic_cli.py b/autopush/diagnostic_cli.py index b44d5594..9bf63605 100644 --- a/autopush/diagnostic_cli.py +++ b/autopush/diagnostic_cli.py @@ -7,7 +7,7 @@ from twisted.logger import Logger from autopush.config import AutopushConfig -from autopush.db import DatabaseManager +from autopush.db import DatabaseManager, Message from autopush.main import AutopushMultiService from autopush.main_argparse import add_shared_args @@ -18,11 +18,12 @@ class EndpointDiagnosticCLI(object): log = Logger() - def __init__(self, sysargs, use_files=True): + def __init__(self, sysargs, sessions, use_files=True): ns = self._load_args(sysargs, use_files) self._conf = conf = AutopushConfig.from_argparse(ns) conf.statsd_host = None - self.db = DatabaseManager.from_config(conf) + self.db = DatabaseManager.from_config(conf, + boto_sessions=sessions) self.db.setup(conf.preflight_uaid) self._endpoint = ns.endpoint self._pp = pprint.PrettyPrinter(indent=4) @@ -69,11 +70,13 @@ def run(self): print("\n") mess_table = rec["current_month"] - chans = self.db.message_tables[mess_table].all_channels(uaid) + chans = Message(mess_table).all_channels(uaid) print("Channels in message table:") self._pp.pprint(chans) -def run_endpoint_diagnostic_cli(sysargs=None, use_files=True): 
- cli = EndpointDiagnosticCLI(sysargs, use_files) +def run_endpoint_diagnostic_cli(sysargs=None, use_files=True, sessions=None): + cli = EndpointDiagnosticCLI(sysargs, + sessions=sessions, + use_files=use_files) return cli.run() diff --git a/autopush/main.py b/autopush/main.py index 5e24f924..77ad7048 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -63,11 +63,11 @@ class AutopushMultiService(MultiService): THREAD_POOL_SIZE = 50 - def __init__(self, conf): + def __init__(self, conf, sessions=None): # type: (AutopushConfig) -> None super(AutopushMultiService, self).__init__() self.conf = conf - self.db = DatabaseManager.from_config(conf) + self.db = DatabaseManager.from_config(conf, boto_sessions=sessions) self.agent = agent_from_config(conf) @staticmethod @@ -115,7 +115,7 @@ def stopService(self): undo_monkey_patch_ssl_wrap_socket() @classmethod - def _from_argparse(cls, ns, **kwargs): + def _from_argparse(cls, ns, sessions=None, **kwargs): # type: (Namespace, **Any) -> AutopushMultiService """Create an instance from argparse/additional kwargs""" # Add some entropy to prevent potential conflicts. @@ -126,10 +126,10 @@ def _from_argparse(cls, ns, **kwargs): preflight_uaid="deadbeef00000000deadbeef" + postfix, **kwargs ) - return cls(conf) + return cls(conf, sessions=sessions) @classmethod - def main(cls, args=None, use_files=True): + def main(cls, args=None, use_files=True, sessions=None): # type: (Sequence[str], bool) -> Any """Entry point to autopush's main command line scripts. 
@@ -148,7 +148,8 @@ def main(cls, args=None, use_files=True): firehose_delivery_stream=ns.firehose_stream_name ) try: - app = cls.from_argparse(ns) + cls.argparse = cls.from_argparse(ns, sessions=sessions) + app = cls.argparse except InvalidConfig as e: log.critical(str(e)) return 1 @@ -172,9 +173,9 @@ class EndpointApplication(AutopushMultiService): endpoint_factory = EndpointHTTPFactory - def __init__(self, conf): + def __init__(self, conf, sessions=None): # type: (AutopushConfig) -> None - super(EndpointApplication, self).__init__(conf) + super(EndpointApplication, self).__init__(conf, sessions=sessions) self.routers = routers_from_config(conf, self.db, self.agent) def setup(self, rotate_tables=True): @@ -209,7 +210,7 @@ def add_endpoint(self): self.addService(StreamServerEndpointService(ep, factory)) @classmethod - def from_argparse(cls, ns): + def from_argparse(cls, ns, sessions=None): # type: (Namespace) -> AutopushMultiService return super(EndpointApplication, cls)._from_argparse( ns, @@ -220,6 +221,7 @@ def from_argparse(cls, ns): cors=not ns.no_cors, bear_hash_key=ns.auth_key, proxy_protocol_port=ns.proxy_protocol_port, + sessions=sessions ) @@ -240,9 +242,9 @@ class ConnectionApplication(AutopushMultiService): websocket_factory = PushServerFactory websocket_site_factory = ConnectionWSSite - def __init__(self, conf): + def __init__(self, conf, sessions=None): # type: (AutopushConfig) -> None - super(ConnectionApplication, self).__init__(conf) + super(ConnectionApplication, self).__init__(conf, sessions=sessions) self.clients = {} # type: Dict[str, PushServerProtocol] def setup(self, rotate_tables=True): @@ -276,7 +278,7 @@ def add_websocket(self): self.add_maybe_ssl(conf.port, site_factory, site_factory.ssl_cf()) @classmethod - def from_argparse(cls, ns): + def from_argparse(cls, ns, sessions=None): # type: (Namespace) -> AutopushMultiService return super(ConnectionApplication, cls)._from_argparse( ns, @@ -298,6 +300,7 @@ def from_argparse(cls, ns): 
auto_ping_timeout=ns.auto_ping_timeout, max_connections=ns.max_connections, close_handshake_timeout=ns.close_handshake_timeout, + sessions=sessions ) @@ -340,7 +343,7 @@ def stopService(self): yield super(RustConnectionApplication, self).stopService() @classmethod - def from_argparse(cls, ns): + def from_argparse(cls, ns, sessions=None): # type: (Namespace) -> AutopushMultiService return super(RustConnectionApplication, cls)._from_argparse( ns, @@ -363,10 +366,11 @@ def from_argparse(cls, ns): auto_ping_timeout=ns.auto_ping_timeout, max_connections=ns.max_connections, close_handshake_timeout=ns.close_handshake_timeout, + sessions=sessions ) @classmethod - def main(cls, args=None, use_files=True): + def main(cls, args=None, use_files=True, sessions=None): # type: (Sequence[str], bool) -> Any """Entry point to autopush's main command line scripts. @@ -385,7 +389,7 @@ def main(cls, args=None, use_files=True): firehose_delivery_stream=ns.firehose_stream_name ) try: - app = cls.from_argparse(ns) + app = cls.from_argparse(ns, sessions=sessions) except InvalidConfig as e: log.critical(str(e)) return 1 diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 177c9a08..93602140 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -231,7 +231,7 @@ def _save_notification(self, uaid_data, notification): "Location": location}, logged_status=204) return deferToThread( - self.db.message_tables[month_table].store_message, + self.db.message_table(month_table).store_message, notification=notification, ) diff --git a/autopush/scripts/drop_user.py b/autopush/scripts/drop_user.py index 2405e176..ab9bec35 100644 --- a/autopush/scripts/drop_user.py +++ b/autopush/scripts/drop_user.py @@ -5,6 +5,7 @@ from autopush.db import ( get_router_table, Router, + BotoSessions, ) from autopush.metrics import SinkMetrics @@ -18,20 +19,24 @@ help="Seconds to pause between batches.") def drop_users(router_table_name, months_ago, batch_size, pause_time): 
router_table = get_router_table(router_table_name)
-    router = Router(router_table, SinkMetrics())
-
+    sessions = BotoSessions()
+    router = Router(router_table, SinkMetrics(), boto_sessions=sessions)
+    session = sessions.fetch()
     click.echo("Deleting users with a last_connect %s months ago."
                % months_ago)
 
     count = 0
-    for deletes in router.drop_old_users(months_ago):
-        click.echo("")
-        count += deletes
-        if count >= batch_size:
-            click.echo("Deleted %s user records, pausing for %s seconds."
-                       % pause_time)
-            time.sleep(pause_time)
-            count = 0
+    try:
+        for deletes in router.drop_old_users(months_ago, session):
+            click.echo("")
+            count += deletes
+            if count >= batch_size:
+                click.echo("Deleted %s user records, pausing for %s seconds."
+                           % (count, pause_time))
+                time.sleep(pause_time)
+                count = 0
+    finally:
+        sessions.release(session)
 
     click.echo("Finished old user purge.")
diff --git a/autopush/tests/__init__.py b/autopush/tests/__init__.py
index adc0ead4..050b1a12 100644
--- a/autopush/tests/__init__.py
+++ b/autopush/tests/__init__.py
@@ -4,12 +4,9 @@
 import subprocess
 
 import boto
-import botocore
-import boto3
 import psutil
 
-import autopush.db
-from autopush.db import create_rotating_message_table
+from autopush.db import create_rotating_message_table, BotoSessions
 
 here_dir = os.path.abspath(os.path.dirname(__file__))
 root_dir = os.path.dirname(os.path.dirname(here_dir))
@@ -17,35 +14,34 @@
 ddb_lib_dir = os.path.join(ddb_dir, "DynamoDBLocal_lib")
 ddb_jar = os.path.join(ddb_dir, "DynamoDBLocal.jar")
 ddb_process = None
+boto_sessions = None
 
 
 def setUp():
     for name in ('boto', 'boto3', 'botocore'):
         logging.getLogger(name).setLevel(logging.CRITICAL)
-    global ddb_process
+    global ddb_process, boto_sessions
     cmd = " ".join([
         "java", "-Djava.library.path=%s" % ddb_lib_dir,
         "-jar", ddb_jar, "-sharedDb", "-inMemory"
     ])
-    conf = botocore.config.Config(
-        region_name=os.getenv('AWS_REGION_NAME', 'us-east-1')
-    )
     ddb_process = subprocess.Popen(cmd, shell=True, env=os.environ)
-
autopush.db.g_dynamodb = boto3.resource( - 'dynamodb', - config=conf, - endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB", "http://127.0.0.1:8000"), + if os.getenv("AWS_LOCAL_DYNAMODB") is None: + os.environ["AWS_LOCAL_DYNAMODB"] = "http://127.0.0.1:8000" + session_args = dict( + endpoint_url=os.getenv("AWS_LOCAL_DYNAMODB"), aws_access_key_id="BogusKey", aws_secret_access_key="BogusKey", ) - - autopush.db.g_client = autopush.db.g_dynamodb.meta.client - + boto_sessions = BotoSessions(conf=session_args) # Setup the necessary message tables message_table = os.environ.get("MESSAGE_TABLE", "message_int_test") - - create_rotating_message_table(prefix=message_table, delta=-1) - create_rotating_message_table(prefix=message_table) + session = boto_sessions.fetch() + create_rotating_message_table(prefix=message_table, delta=-1, + boto_session=session) + create_rotating_message_table(prefix=message_table, + boto_session=session) + boto_sessions.release(session) def tearDown(): diff --git a/autopush/tests/support.py b/autopush/tests/support.py index 1d4e4bfd..8b11c38c 100644 --- a/autopush/tests/support.py +++ b/autopush/tests/support.py @@ -8,6 +8,7 @@ Router, ) from autopush.metrics import SinkMetrics +import autopush.tests @implementer(ILogObserver) @@ -44,5 +45,6 @@ def test_db(metrics=None): router_conf=DDBTableConfig(tablename='router'), router=Mock(spec=Router), message_conf=DDBTableConfig(tablename='message'), - metrics=SinkMetrics() if metrics is None else metrics + metrics=SinkMetrics() if metrics is None else metrics, + sessions=autopush.tests.boto_sessions, ) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index 920565dc..5c4e64f4 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -11,21 +11,26 @@ import pytest from autopush.db import ( - get_rotating_message_table, + get_rotating_message_tablename, get_router_table, create_router_table, preflight_check, table_exists, Message, Router, + BotoSessions, generate_last_connect, 
make_rotating_tablename, _drop_table, - _make_table) + _make_table, + MAX_SESSIONS, + ) from autopush.exceptions import AutopushException from autopush.metrics import SinkMetrics from autopush.utils import WebPushNotification +# nose fails to import sessions correctly. +import autopush.tests dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) @@ -44,17 +49,41 @@ def make_webpush_notification(uaid, chid, ttl=100): class DbUtilsTest(unittest.TestCase): def test_make_table(self): + fake_session = Mock() fake_func = Mock() fake_table = "DoesNotExist_{}".format(uuid.uuid4()) - _make_table(fake_func, fake_table, 5, 10) - assert fake_func.call_args[0] == (fake_table, 5, 10) + _make_table(fake_func, fake_table, 5, 10, boto_session=fake_session) + assert fake_func.call_args[0] == (fake_table, 5, 10, fake_session) + + +class SessionsTest(unittest.TestCase): + def test_session_pool(self): + + testpool = BotoSessions(conf={'endpoint_url': 'http://localhost:8000'}) + hold = [] + with pytest.raises(ClientError): + for i in range(0, MAX_SESSIONS+1): + hold.append(testpool.fetch()) + + assert len(hold) == MAX_SESSIONS class DbCheckTestCase(unittest.TestCase): + def setUp(cls): + cls.session = autopush.tests.boto_sessions.fetch() + + def tearDown(cls): + autopush.tests.boto_sessions.release(cls.session) + def test_preflight_check_fail(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + router = Router(get_router_table(boto_session=self.session), + SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + message = Message(get_rotating_message_tablename( + boto_session=self.session), + SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) def raise_exc(*args, **kwargs): # pragma: no cover raise Exception("Oops") @@ -63,34 +92,44 @@ def raise_exc(*args, **kwargs): # pragma: no cover router.clear_node.side_effect = raise_exc 
with pytest.raises(Exception):
-            preflight_check(message, router)
+            preflight_check(message, router, boto_session=self.session)
 
     def test_preflight_check(self):
-        router = Router(get_router_table(), SinkMetrics())
-        message = Message(get_rotating_message_table(), SinkMetrics())
+        router = Router({}, SinkMetrics(),
+                        boto_sessions=autopush.tests.boto_sessions)
+        message = Message(get_rotating_message_tablename(
+            boto_session=self.session),
+            SinkMetrics(),
+            boto_sessions=autopush.tests.boto_sessions)
 
         pf_uaid = "deadbeef00000000deadbeef01010101"
-        preflight_check(message, router, pf_uaid)
+        preflight_check(message, router, pf_uaid, self.session)
         # now check that the database reports no entries.
-        _, notifs = message.fetch_messages(uuid.UUID(pf_uaid))
+        _, notifs = message.fetch_messages(uuid.UUID(pf_uaid),
+                                           self.session)
         assert len(notifs) == 0
         with pytest.raises(ItemNotFound):
             router.get_uaid(pf_uaid)
 
     def test_preflight_check_wait(self):
-        router = Router(get_router_table(), SinkMetrics())
-        message = Message(get_rotating_message_table(), SinkMetrics())
+        router = Router({}, SinkMetrics(),
+                        boto_sessions=autopush.tests.boto_sessions)
+        message = Message(get_rotating_message_tablename(
+            boto_session=self.session),
+            SinkMetrics(),
+            boto_sessions=autopush.tests.boto_sessions)
 
         values = ["PENDING", "ACTIVE"]
         message.table_status = Mock(side_effect=values)
 
         pf_uaid = "deadbeef00000000deadbeef01010101"
-        preflight_check(message, router, pf_uaid)
+        preflight_check(message, router, pf_uaid, self.session)
         # now check that the database reports no entries.
- _, notifs = message.fetch_messages(uuid.UUID(pf_uaid)) + _, notifs = message.fetch_messages(uuid.UUID(pf_uaid), + boto_session=self.session) assert len(notifs) == 0 with pytest.raises(ItemNotFound): - router.get_uaid(pf_uaid) + router.get_uaid(pf_uaid, boto_session=self.session) def test_get_month(self): from autopush.db import get_month @@ -126,21 +165,25 @@ def test_normalize_id(self): class MessageTestCase(unittest.TestCase): def setUp(self): - table = get_rotating_message_table() + self.session = autopush.tests.boto_sessions.fetch() + table = get_rotating_message_tablename(boto_session=self.session) self.real_table = table self.uaid = str(uuid.uuid4()) def tearDown(self): + autopush.tests.boto_sessions.release(self.session) pass def test_register(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + message.register_channel(self.uaid, chid, boto_session=self.session) + lm = self.session.Table(m) # Verify it's in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -157,12 +200,15 @@ def test_register(self): def test_unregister(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) message.register_channel(self.uaid, chid) # Verify its in the db - response = m.query( + lm = self.session.Table(m) + # Verify it's in the db + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -179,10 +225,10 @@ def test_unregister(self): assert len(results) == 1 assert results[0]["chids"] == {chid} - message.unregister_channel(self.uaid, chid) + 
message.unregister_channel(self.uaid, chid, boto_session=self.session) # Verify its not in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -200,78 +246,93 @@ def test_unregister(self): assert results[0].get("chids") is None # Test for the very unlikely case that there's no 'chid' - m.update_item = Mock(return_value={ + mtable = Mock() + mtable.update_item = Mock(return_value={ 'Attributes': {'uaid': self.uaid}, 'ResponseMetaData': {} }) - r = message.unregister_channel(self.uaid, dummy_chid) + message.table = Mock(return_value=mtable) + r = message.unregister_channel(self.uaid, dummy_chid, + boto_session=self.session) assert r is False def test_all_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) - message.register_channel(self.uaid, chid2) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + message.register_channel(self.uaid, chid, boto_session=self.session) + message.register_channel(self.uaid, chid2, boto_session=self.session) - _, chans = message.all_channels(self.uaid) + _, chans = message.all_channels(self.uaid, boto_session=self.session) assert chid in chans assert chid2 in chans - message.unregister_channel(self.uaid, chid2) - _, chans = message.all_channels(self.uaid) + message.unregister_channel(self.uaid, chid2, boto_session=self.session) + _, chans = message.all_channels(self.uaid, boto_session=self.session) assert chid2 not in chans assert chid in chans def test_all_channels_fail(self): - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.table.get_item = Mock() - message.table.get_item.return_value = { + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + 
boto_sessions=autopush.tests.boto_sessions) + + mtable = Mock() + mtable.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } - + message.table = Mock(return_value=mtable) res = message.all_channels(self.uaid) assert res == (False, set([])) def test_save_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) - message.register_channel(self.uaid, chid2) - - exists, chans = message.all_channels(self.uaid) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + message.register_channel(self.uaid, chid, boto_session=self.session) + message.register_channel(self.uaid, chid2, boto_session=self.session) + + exists, chans = message.all_channels(self.uaid, + boto_session=self.session) new_uaid = uuid.uuid4().hex - message.save_channels(new_uaid, chans) - _, new_chans = message.all_channels(new_uaid) + message.save_channels(new_uaid, chans, boto_session=self.session) + _, new_chans = message.all_channels(new_uaid, + boto_session=self.session) assert chans == new_chans def test_all_channels_no_uaid(self): - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - exists, chans = message.all_channels(dummy_uaid) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + exists, chans = message.all_channels(dummy_uaid, + boto_session=self.session) assert chans == set([]) def test_message_storage(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) - message.register_channel(self.uaid, chid2) - - message.store_message(make_webpush_notification(self.uaid, chid)) - 
message.store_message(make_webpush_notification(self.uaid, chid)) - message.store_message(make_webpush_notification(self.uaid, chid)) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + message.register_channel(self.uaid, chid, boto_session=self.session) + message.register_channel(self.uaid, chid2, boto_session=self.session) + + message.store_message(make_webpush_notification(self.uaid, chid), + boto_session=self.session) + message.store_message(make_webpush_notification(self.uaid, chid), + boto_session=self.session) + message.store_message(make_webpush_notification(self.uaid, chid), + boto_session=self.session) _, all_messages = message.fetch_timestamp_messages( - uuid.UUID(self.uaid), " ") + uuid.UUID(self.uaid), " ", boto_session=self.session) assert len(all_messages) == 3 def test_message_storage_overwrite(self): @@ -283,30 +344,34 @@ def test_message_storage_overwrite(self): notif2 = make_webpush_notification(self.uaid, chid) notif3 = make_webpush_notification(self.uaid, chid2) notif2.message_id = notif1.message_id - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) - message.register_channel(self.uaid, chid2) - - message.store_message(notif1) - message.store_message(notif2) - message.store_message(notif3) - - all_messages = list(message.fetch_messages(uuid.UUID(self.uaid))) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + message.register_channel(self.uaid, chid, boto_session=self.session) + message.register_channel(self.uaid, chid2, boto_session=self.session) + + message.store_message(notif1, boto_session=self.session) + message.store_message(notif2, boto_session=self.session) + message.store_message(notif3, boto_session=self.session) + + all_messages = 
list(message.fetch_messages(uuid.UUID(self.uaid), + boto_session=self.session)) assert len(all_messages) == 2 def test_message_delete_fail_condition(self): notif = make_webpush_notification(dummy_uaid, dummy_chid) notif.message_id = notif.update_id = dummy_uaid - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) + m = get_rotating_message_tablename(boto_session=self.session) + message = Message(m, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) def raise_condition(*args, **kwargs): raise ClientError({}, 'delete_item') - message.table = Mock() - message.table.delete_item.side_effect = raise_condition - result = message.delete_message(notif) + m_de = Mock() + m_de.delete_item = Mock(side_effect=raise_condition) + message.table = Mock(return_value=m_de) + result = message.delete_message(notif, boto_session=self.session) assert result is False def test_message_rotate_table_with_date(self): @@ -314,22 +379,21 @@ def test_message_rotate_table_with_date(self): future = (datetime.today() + timedelta(days=32)).date() tbl_name = make_rotating_tablename(prefix, date=future) - m = get_rotating_message_table(prefix=prefix, date=future) - assert m.table_name == tbl_name + m = get_rotating_message_tablename(prefix=prefix, date=future, + boto_session=self.session) + assert m == tbl_name # Clean up the temp table. 
- _drop_table(tbl_name) + _drop_table(tbl_name, boto_session=self.session) class RouterTestCase(unittest.TestCase): @classmethod - def setup_class(self): - table = get_router_table() - self.real_table = table - self.real_connection = table.meta.client + def setUpClass(cls): + cls.boto_session = autopush.tests.boto_sessions.fetch() @classmethod - def teardown_class(self): - self.real_table.meta.client = self.real_connection + def tearDownClass(cls): + autopush.tests.boto_sessions.release(cls.boto_session) def _create_minimal_record(self): data = { @@ -342,141 +406,151 @@ def _create_minimal_record(self): def test_drop_old_users(self): # First create a bunch of users - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) # Purge any existing users from previous runs. - router.drop_old_users(0) + router.drop_old_users(months_ago=0, boto_session=self.boto_session) for _ in range(0, 53): - router.register_user(self._create_minimal_record()) + router.register_user(self._create_minimal_record(), + boto_session=self.boto_session) - results = router.drop_old_users(months_ago=0) + results = router.drop_old_users(months_ago=0, + boto_session=self.boto_session) assert list(results) == [25, 25, 3] def test_custom_tablename(self): db_name = "router_%s" % uuid.uuid4() - assert not table_exists(db_name) - create_router_table(db_name) - assert table_exists(db_name) + assert not table_exists(db_name, boto_session=self.boto_session) + create_router_table(db_name, boto_session=self.boto_session) + assert table_exists(db_name, boto_session=self.boto_session) # Clean up the temp table. 
- _drop_table(db_name) + _drop_table(db_name, boto_session=self.boto_session) def test_provisioning(self): db_name = "router_%s" % uuid.uuid4() - r = create_router_table(db_name, 3, 17) + r = create_router_table(db_name, 3, 17, boto_session=self.boto_session) assert r.provisioned_throughput.get('ReadCapacityUnits') == 3 assert r.provisioned_throughput.get('WriteCapacityUnits') == 17 def test_no_uaid_found(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) with pytest.raises(ItemNotFound): - router.get_uaid(uaid) + router.get_uaid(uaid, boto_session=self.boto_session) def test_uaid_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) router.table = Mock() def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.get_item.side_effect = raise_condition + mm = Mock() + mm.get_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: - router.get_uaid(uaid="asdf") + router.get_uaid(uaid="asdf", boto_session=self.boto_session) assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_register_user_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + mm = Mock() + mm.client = Mock() + router.table = Mock(return_value=mm) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 
'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + mm.update_item = Mock(side_effect=raise_condition) with pytest.raises(ClientError) as ex: router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, - router_type="webpush")) + router_type="webpush"), + boto_session=self.boto_session) assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_register_user_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + router.table(self.boto_session).meta.client = Mock() def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - - router.table.update_item = Mock(side_effect=raise_error) + mm = Mock() + mm.update_item = Mock(side_effect=raise_error) + router.table = Mock(return_value=mm) res = router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, - router_type="webpush")) + router_type="webpush"), + boto_session=self.boto_session) assert res == (False, {}) def test_clear_node_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + mm = Mock() + mm.put_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", node_id="asdf", - router_type="webpush")) + 
router_type="webpush"), + boto_session=self.boto_session) assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_clear_node_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_put_item' ) - router.table.put_item = Mock(side_effect=raise_error) + router.table(self.boto_session).put_item = Mock( + side_effect=raise_error) res = router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", node_id="asdf", - router_type="webpush")) + router_type="webpush"), + boto_session=self.boto_session) + assert res is False def test_incomplete_uaid(self): # Older records may be incomplete. We can't inject them using normal # methods. uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 200 }, @@ -484,83 +558,97 @@ def test_incomplete_uaid(self): "uaid": uuid.uuid4().hex } } + mm.delete_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 200 + }, + } + router.table = Mock(return_value=mm) + router.drop_user = Mock() try: - router.register_user(dict(uaid=uaid)) + router.register_user(dict(uaid=uaid), + boto_session=self.boto_session) except AutopushException: pass with pytest.raises(ItemNotFound): - router.get_uaid(uaid) + router.get_uaid(uaid, boto_session=self.boto_session) assert router.drop_user.called def test_failed_uaid(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, 
SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } + router.table = Mock(return_value=mm) + router.drop_user = Mock() with pytest.raises(ItemNotFound): - router.get_uaid(uaid) + router.get_uaid(uaid, boto_session=self.boto_session) def test_save_new(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) # Sadly, moto currently does not return an empty value like boto # when not updating data. - router.table.update_item = Mock(return_value={}) + router.table(self.boto_session).update_item = Mock(return_value={}) result = router.register_user(dict(uaid=dummy_uaid, node_id="me", router_type="webpush", - connected_at=1234)) + connected_at=1234), + boto_session=self.boto_session) assert result[0] is True def test_save_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + router.table(self.boto_session).update_item = Mock( + side_effect=raise_condition + ) router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, router_type="webpush") - result = router.register_user(router_data) + result = router.register_user(router_data, + boto_session=self.boto_session) assert result == (False, {}) def test_node_clear(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + 
boto_sessions=autopush.tests.boto_sessions) # Register a node user router.register_user(dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, - router_type="webpush")) + router_type="webpush"), + boto_session=self.boto_session) # Verify - user = router.get_uaid(dummy_uaid) + user = router.get_uaid(dummy_uaid, boto_session=self.boto_session) assert user["node_id"] == "asdf" assert user["connected_at"] == 1234 assert user["router_type"] == "webpush" # Clear - router.clear_node(user) + router.clear_node(user, boto_session=self.boto_session) # Verify - user = router.get_uaid(dummy_uaid) + user = router.get_uaid(dummy_uaid, boto_session=self.boto_session) assert user.get("node_id") is None assert user["connected_at"] == 1234 assert user["router_type"] == "webpush" def test_node_clear_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) def raise_condition(*args, **kwargs): raise ClientError( @@ -568,21 +656,22 @@ def raise_condition(*args, **kwargs): 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + router.table(boto_session=self.boto_session).put_item = Mock( + side_effect=raise_condition) data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) - result = router.clear_node(data) + result = router.clear_node(data, boto_session=self.boto_session) assert result is False def test_drop_user(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics(), + boto_sessions=autopush.tests.boto_sessions) # Register a node user router.register_user(dict(uaid=uaid, node_id="asdf", router_type="webpush", connected_at=1234)) - result = router.drop_user(uaid) + result = router.drop_user(uaid, boto_session=self.boto_session) assert result is True # Deleting already deleted record should return false. 
- result = router.drop_user(uaid) + result = router.drop_user(uaid, boto_session=self.boto_session) assert result is False diff --git a/autopush/tests/test_diagnostic_cli.py b/autopush/tests/test_diagnostic_cli.py index 9094f01e..0979bbf3 100644 --- a/autopush/tests/test_diagnostic_cli.py +++ b/autopush/tests/test_diagnostic_cli.py @@ -2,6 +2,8 @@ from mock import Mock, patch +import autopush.tests + class FakeDict(dict): pass @@ -10,14 +12,19 @@ class FakeDict(dict): class DiagnosticCLITestCase(unittest.TestCase): def _makeFUT(self, *args, **kwargs): from autopush.diagnostic_cli import EndpointDiagnosticCLI - return EndpointDiagnosticCLI(*args, use_files=False, **kwargs) + return EndpointDiagnosticCLI(*args, + sessions=autopush.tests.boto_sessions, + use_files=False, + **kwargs) def test_basic_load(self): cli = self._makeFUT([ "--router_tablename=fred", "http://someendpoint", ]) - assert cli.db.router.table.table_name == "fred" + session = cli.db.sessions.fetch() + assert cli.db.router.table(boto_session=session).table_name == "fred" + cli.db.sessions.release(session) def test_bad_endpoint(self): cli = self._makeFUT([ @@ -27,9 +34,10 @@ def test_bad_endpoint(self): returncode = cli.run() assert returncode not in (None, 0) + @patch("autopush.diagnostic_cli.Message") @patch("autopush.diagnostic_cli.AutopushConfig") @patch("autopush.diagnostic_cli.DatabaseManager.from_config") - def test_successfull_lookup(self, mock_db_cstr, mock_conf_class): + def test_successfull_lookup(self, mock_db_cstr, mock_conf_class, mock_msg): from autopush.diagnostic_cli import run_endpoint_diagnostic_cli mock_conf_class.return_value = mock_conf = Mock() mock_conf.parse_endpoint.return_value = dict( @@ -39,11 +47,14 @@ def test_successfull_lookup(self, mock_db_cstr, mock_conf_class): mock_db.router.get_uaid.return_value = mock_item = FakeDict() mock_item._data = {} mock_item["current_month"] = "201608120002" - mock_message_table = Mock() - mock_db.message_tables = {"201608120002": 
mock_message_table} - - run_endpoint_diagnostic_cli([ - "--router_tablename=fred", - "http://something/wpush/v1/legit_endpoint", - ], use_files=False) - mock_message_table.all_channels.assert_called() + mock_db.message_tables = ["201608120002"] + mock_msg.return_value = mock_message = Mock() + + run_endpoint_diagnostic_cli( + sysargs=[ + "--router_tablename=fred", + "http://something/wpush/v1/legit_endpoint", + ], + use_files=False, + sessions=autopush.tests.boto_sessions) + mock_message.all_channels.assert_called() diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index af7426aa..b7e5cb2d 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -14,6 +14,7 @@ Message, ItemNotFound, has_connected_this_month, + Router ) from autopush.exceptions import RouterException from autopush.http import EndpointHTTPFactory @@ -51,7 +52,7 @@ def setUp(self): crypto_key='AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=', ) db = test_db() - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) self.fernet_mock = conf.fernet = Mock(spec=Fernet) app = EndpointHTTPFactory.for_handler(MessageHandler, conf, db=db) @@ -140,12 +141,15 @@ def setUp(self): self.fernet_mock = conf.fernet = Mock(spec=Fernet) self.db = db = test_db() + db.router = Mock(spec=Router) db.router.register_user.return_value = (True, {}, {}) db.router.get_uaid.return_value = { "router_type": "test", "router_data": dict() } - db.create_initial_message_tables() + session = self.db.sessions.fetch() + db.create_initial_message_tables(boto_session=session) + self.db.sessions.release(session) self.routers = routers = routers_from_config(conf, db, Mock()) routers["test"] = Mock(spec=IRouter) diff --git a/autopush/tests/test_health.py b/autopush/tests/test_health.py index 3e96896f..485611b7 100644 --- a/autopush/tests/test_health.py +++ b/autopush/tests/test_health.py @@ -1,15 +1,13 @@ import json import 
twisted.internet.base -from boto.dynamodb2.exceptions import InternalServerError from mock import Mock from twisted.internet.defer import inlineCallbacks from twisted.logger import globalLogPublisher from twisted.trial import unittest -import autopush.db from autopush import __version__ -from autopush.config import AutopushConfig +from autopush.config import AutopushConfig, DDBTableConfig from autopush.db import DatabaseManager from autopush.exceptions import MissingTableException from autopush.http import EndpointHTTPFactory @@ -17,21 +15,26 @@ from autopush.tests.client import Client from autopush.tests.support import TestingLogObserver from autopush.web.health import HealthHandler, StatusHandler +import autopush.tests class HealthTestCase(unittest.TestCase): def setUp(self): - self.timeout = 0.5 + self.timeout = 6 twisted.internet.base.DelayedCall.debug = True conf = AutopushConfig( hostname="localhost", statsd_host=None, + router_table=DDBTableConfig(tablename="router_test") ) - db = DatabaseManager.from_config(conf) - db.client = autopush.db.g_client - db.setup_tables() + db = DatabaseManager.from_config( + conf, + boto_sessions=autopush.tests.boto_sessions) + session = autopush.tests.boto_sessions.fetch() + db.setup_tables(boto_session=session) + autopush.tests.boto_sessions.release(session) # ignore logging logs = TestingLogObserver() @@ -50,82 +53,25 @@ def test_healthy(self): "version": __version__, "clients": 0, "storage": {"status": "OK"}, - "router": {"status": "OK"} + "router_test": {"status": "OK"} }) - @inlineCallbacks - def test_aws_error(self): - - def raise_error(*args, **kwargs): - raise InternalServerError(None, None) - - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = Mock(side_effect=raise_error) - - yield self._assert_reply({ - "status": "NOT OK", - "version": __version__, - "clients": 0, - "storage": { - "status": "NOT OK", - "error": "Server error" - }, - "router": { - 
"status": "NOT OK", - "error": "Server error" - } - }, InternalServerError) - - self.client.app.db.client = safe - @inlineCallbacks def test_nonexistent_table(self): - no_tables = Mock(return_value={"TableNames": []}) - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = no_tables + session = autopush.tests.boto_sessions.fetch() + self.client.app.db.router.table(boto_session=session).delete() + autopush.tests.boto_sessions.release(session) yield self._assert_reply({ "status": "NOT OK", "version": __version__, "clients": 0, - "storage": { - "status": "NOT OK", - "error": "Nonexistent table" - }, - "router": { + "storage": {"status": "OK"}, + "router_test": { "status": "NOT OK", "error": "Nonexistent table" } }, MissingTableException) - self.client.app.db.client = safe - - @inlineCallbacks - def test_internal_error(self): - def raise_error(*args, **kwargs): - raise Exception("synergies not aligned") - - safe = self.client.app.db.client - self.client.app.db.client = Mock() - self.client.app.db.client.list_tables = Mock( - side_effect=raise_error - ) - - yield self._assert_reply({ - "status": "NOT OK", - "version": __version__, - "clients": 0, - "storage": { - "status": "NOT OK", - "error": "Internal error" - }, - "router": { - "status": "NOT OK", - "error": "Internal error" - } - }, Exception) - self.client.app.db.client = safe @inlineCallbacks def _assert_reply(self, reply, exception=None): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index b819be2e..945f2f63 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -36,12 +36,14 @@ from autopush.config import AutopushConfig from autopush.db import ( get_month, - has_connected_this_month + has_connected_this_month, + Message, ) from autopush.logging import begin_or_register from autopush.main import ConnectionApplication, EndpointApplication from autopush.utils import base64url_encode, 
normalize_id from autopush.metrics import SinkMetrics, DatadogMetrics +import autopush.tests from autopush.tests.support import TestingLogObserver from autopush.websocket import PushServerFactory @@ -130,7 +132,8 @@ def register(self, chid=None, key=None): key=key)) log.debug("Send: %s", msg) self.ws.send(msg) - result = json.loads(self.ws.recv()) + rcv = self.ws.recv() + result = json.loads(rcv) log.debug("Recv: %s", result) assert result["status"] == 200 assert result["channelID"] == chid @@ -227,7 +230,7 @@ def send_notification(self, channel=None, version=None, data=None, if self.ws and self.ws.connected: return object.__getattribute__(self, "get_notification")(timeout) - def get_notification(self, timeout=1): + def get_notification(self, timeout=6): orig_timeout = self.ws.gettimeout() self.ws.settimeout(timeout) try: @@ -319,15 +322,19 @@ def setUp(self): ) # Endpoint HTTP router - self.ep = ep = EndpointApplication(ep_conf) - ep.db.client = db.g_client + self.ep = ep = EndpointApplication( + ep_conf, + sessions=autopush.tests.boto_sessions + ) ep.setup(rotate_tables=False) ep.startService() self.addCleanup(ep.stopService) # Websocket server - self.conn = conn = ConnectionApplication(conn_conf) - conn.db.client = db.g_client + self.conn = conn = ConnectionApplication( + conn_conf, + sessions=autopush.tests.boto_sessions + ) conn.setup(rotate_tables=False) conn.startService() self.addCleanup(conn.stopService) @@ -472,10 +479,10 @@ def test_legacy_simplepush_record(self): return_value={'router_type': 'simplepush', 'uaid': uaid, 'current_month': self.ep.db.current_msg_month}) - self.ep.db.message_tables[ - self.ep.db.current_msg_month].all_channels = Mock( - return_value=(True, client.channels)) + safe = db.Message.all_channels + db.Message.all_channels = Mock(return_value=(True, client.channels)) yield client.send_notification() + db.Message.all_channels = safe yield self.shut_down(client) @patch("autopush.metrics.datadog") @@ -537,12 +544,13 @@ def 
test_webpush_data_save_fail(self): yield client.hello() yield client.register(chid=chan) yield client.disconnect() - self.ep.db.message_tables[ - self.ep.db.current_msg_month].store_message = Mock( + safe = db.Message.store_message + db.Message.store_message = Mock( return_value=False) yield client.send_notification(channel=chan, data=test["data"], status=201) + db.Message.store_message = safe yield self.shut_down(client) @@ -1152,20 +1160,24 @@ def test_delete_saved_notification(self): def test_webpush_monthly_rotation(self): from autopush.db import make_rotating_tablename client = yield self.quick_register() + session = self.conn.db.sessions.fetch() yield client.disconnect() # Move the client back one month to the past last_month = make_rotating_tablename( prefix=self.conn.conf.message_table.tablename, delta=-1) - lm_message = self.conn.db.message_tables[last_month] + lm_message = Message(last_month, boto_sessions=self.conn.db.sessions) yield deferToThread( self.conn.db.router.update_message_month, client.uaid, - last_month + last_month, + boto_session=session ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) assert c["current_month"] == last_month # Verify last_connect is current, then move that back @@ -1178,32 +1190,39 @@ def test_webpush_monthly_rotation(self): self.conn.db.router._update_last_connect, client.uaid, last_connect, + boto_session=session, ) - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) assert has_connected_this_month(c) is False # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.db.message.all_channels, client.uaid + self.conn.db.message.all_channels, + client.uaid, + boto_session=session ) assert exists is True assert len(chans) == 1 yield deferToThread( 
lm_message.save_channels, client.uaid, - chans + chans, + boto_session=session ) # Remove the channels entry entirely from this month yield deferToThread( - self.conn.db.message.table.delete_item, + self.conn.db.message.table(session).delete_item, Key={'uaid': client.uaid, 'chidmessageid': ' '} ) # Verify the channel is gone exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, + boto_session=session ) assert exists is False assert len(chans) == 0 @@ -1215,7 +1234,8 @@ def test_webpush_monthly_rotation(self): yield client.send_notification(data=data) ts, notifs = yield deferToThread(lm_message.fetch_timestamp_messages, uuid.UUID(client.uaid), - " ") + " ", + boto_session=session) assert len(notifs) == 1 # Connect the client, verify the migration @@ -1239,7 +1259,9 @@ def test_webpush_monthly_rotation(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) if c["current_month"] == self.conn.db.current_msg_month: break else: @@ -1247,7 +1269,9 @@ def test_webpush_monthly_rotation(self): # Verify the month update in the router table c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) assert c["current_month"] == self.conn.db.current_msg_month assert server_client.ps.rotate_message_table is False @@ -1257,31 +1281,36 @@ def test_webpush_monthly_rotation(self): # Verify the channels were moved exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, + boto_session=session ) assert exists is True assert len(chans) == 1 - + self.conn.db.sessions.release(session) yield self.shut_down(client) @inlineCallbacks def test_webpush_monthly_rotation_prior_record_exists(self): from autopush.db import make_rotating_tablename client = yield self.quick_register() + 
session = self.conn.db.sessions.fetch() yield client.disconnect() # Move the client back one month to the past last_month = make_rotating_tablename( prefix=self.conn.conf.message_table.tablename, delta=-1) - lm_message = self.conn.db.message_tables[last_month] + lm_message = Message(last_month) yield deferToThread( self.conn.db.router.update_message_month, client.uaid, - last_month + last_month, + boto_session=session ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) assert c["current_month"] == last_month # Verify last_connect is current, then move that back @@ -1290,21 +1319,24 @@ def test_webpush_monthly_rotation_prior_record_exists(self): yield deferToThread( self.conn.db.router._update_last_connect, client.uaid, - int("%s%s020001" % (today.year, str(today.month).zfill(2))) + int("%s%s020001" % (today.year, str(today.month).zfill(2))), + boto_session=session ) c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) assert has_connected_this_month(c) is False # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.db.message.all_channels, client.uaid + self.conn.db.message.all_channels, client.uaid, + boto_session=session ) assert exists is True assert len(chans) == 1 yield deferToThread( lm_message.save_channels, client.uaid, - chans + chans, + boto_session=session ) # Send in a notification, verify it landed in last months notification @@ -1314,7 +1346,8 @@ def test_webpush_monthly_rotation_prior_record_exists(self): yield client.send_notification(data=data) _, notifs = yield deferToThread(lm_message.fetch_timestamp_messages, uuid.UUID(client.uaid), - " ") + " ", + boto_session=session) assert len(notifs) == 1 # Connect the client, verify the migration @@ -1338,7 +1371,9 @@ def test_webpush_monthly_rotation_prior_record_exists(self): start = time.time() while time.time()-start < 2: 
c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) if c["current_month"] == self.conn.db.current_msg_month: break else: @@ -1355,17 +1390,19 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Verify the channels were moved exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, + boto_session=session ) assert exists is True assert len(chans) == 1 - + self.conn.db.sessions.release(session) yield self.shut_down(client) @inlineCallbacks def test_webpush_monthly_rotation_no_channels(self): from autopush.db import make_rotating_tablename client = Client("ws://localhost:9010/") + session = self.conn.db.sessions.fetch() yield client.connect() yield client.hello() yield client.disconnect() @@ -1376,17 +1413,21 @@ def test_webpush_monthly_rotation_no_channels(self): yield deferToThread( self.conn.db.router.update_message_month, client.uaid, - last_month + last_month, + boto_session=session ) # Verify the move - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) assert c["current_month"] == last_month # Verify there's no channels exists, chans = yield deferToThread( self.conn.db.message.all_channels, - client.uaid + client.uaid, + boto_session=session ) assert exists is False assert len(chans) == 0 @@ -1403,17 +1444,22 @@ def test_webpush_monthly_rotation_no_channels(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.db.router.get_uaid, client.uaid) + self.conn.db.router.get_uaid, + client.uaid, + boto_session=session, + ) if c["current_month"] == self.conn.db.current_msg_month: break else: yield deferToThread(time.sleep, 0.2) # Verify the month update in the router table - c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + c = yield 
deferToThread(self.conn.db.router.get_uaid, + client.uaid, + boto_session=session) assert c["current_month"] == self.conn.db.current_msg_month assert server_client.ps.rotate_message_table is False - + self.conn.db.sessions.release(session) yield self.shut_down(client) @inlineCallbacks diff --git a/autopush/tests/test_main.py b/autopush/tests/test_main.py index 4b95aa07..18b45ed9 100644 --- a/autopush/tests/test_main.py +++ b/autopush/tests/test_main.py @@ -11,7 +11,7 @@ import autopush.db from autopush.config import AutopushConfig -from autopush.db import DatabaseManager, get_rotating_message_table +from autopush.db import DatabaseManager, get_rotating_message_tablename from autopush.exceptions import InvalidConfig from autopush.http import skip_request_logging from autopush.main import ( @@ -20,6 +20,7 @@ ) from autopush.tests.support import test_db from autopush.utils import resolve_ip +import autopush.tests connection_main = ConnectionApplication.main endpoint_main = EndpointApplication.main @@ -51,7 +52,9 @@ def test_new_month(self): db = test_db() db._tomorrow = Mock() db._tomorrow.return_value = tomorrow - db.create_initial_message_tables() + session = db.sessions.fetch() + db.create_initial_message_tables(boto_session=session) + db.sessions.release(session) assert len(db.message_tables) == 3 @@ -60,19 +63,23 @@ def test_update_rotating_tables(self): from autopush.db import get_month conf = AutopushConfig( hostname="example.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) - db.create_initial_message_tables() + db = DatabaseManager.from_config( + conf, + boto_sessions=autopush.tests.boto_sessions) + session = db.sessions.fetch() + db.create_initial_message_tables(boto_session=session) + db.sessions.release(session) # Erase the tables it has on init, and move current month back one last_month = get_month(-1) db.current_month = last_month.month - db.message_tables = {} + db.message_tables = [] # Create the next month's table, just in case 
today is the day before # a new month, in which case the lack of keys will cause an error in # update_rotating_tables next_month = get_month(1) - db.message_tables[next_month.month] = None + assert next_month.month not in db.message_tables # Get the deferred back e = Deferred() @@ -115,45 +122,47 @@ def test_update_rotating_tables_month_end(self): conf = AutopushConfig( hostname="example.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + boto_sessions=autopush.tests.boto_sessions) db._tomorrow = Mock(return_value=tomorrow) - db.create_initial_message_tables() + session = db.sessions.fetch() + db.create_initial_message_tables(boto_session=session) # We should have 3 tables, one for next/this/last month assert len(db.message_tables) == 3 # Grab next month's table name and remove it - next_month = get_rotating_message_table( + next_month = get_rotating_message_tablename( conf.message_table.tablename, - delta=1 + delta=1, + boto_session=session ) - db.message_tables.pop(next_month.table_name) + db.sessions.release(session) + db.message_tables.pop(db.message_tables.index(next_month)) # Get the deferred back d = db.update_rotating_tables() def check_tables(result): assert len(db.message_tables) == 3 - assert next_month.table_name in db.message_tables + assert next_month in db.message_tables d.addCallback(check_tables) return d def test_update_not_needed(self): - from autopush.db import get_month conf = AutopushConfig( hostname="google.com", resolve_hostname=True) - db = DatabaseManager.from_config(conf) - db.create_initial_message_tables() + db = DatabaseManager.from_config( + conf, + boto_sessions=autopush.tests.boto_sessions) + session = db.sessions.fetch() + db.create_initial_message_tables(boto_session=session) + db.sessions.release(session) # Erase the tables it has on init, and move current month back one - db.message_tables = {} - - # Create the next month's table, just in case today is the day 
before - # a new month, in which case the lack of keys will cause an error in - # update_rotating_tables - next_month = get_month(1) - db.message_tables[next_month.month] = None + db.message_tables = [] # Get the deferred back e = Deferred() @@ -184,7 +193,7 @@ def tearDown(self): mock.stop() def test_basic(self): - connection_main([], False) + connection_main([], False, sessions=autopush.tests.boto_sessions) def test_ssl(self): connection_main([ @@ -193,12 +202,12 @@ def test_ssl(self): "--ssl_key=keys/server.key", "--router_ssl_cert=keys/server.crt", "--router_ssl_key=keys/server.key", - ], False) + ], False, sessions=autopush.tests.boto_sessions) def test_memusage(self): connection_main([ "--memusage_port=8083", - ], False) + ], False, sessions=autopush.tests.boto_sessions) def test_skip_logging(self): # Should skip setting up logging on the handler @@ -277,15 +286,18 @@ def tearDown(self): autopush.db.key_hash = "" def test_basic(self): - endpoint_main([ - ], False) + endpoint_main( + [], + False, + sessions=autopush.tests.boto_sessions + ) def test_ssl(self): endpoint_main([ "--ssl_dh_param=keys/dhparam.pem", "--ssl_cert=keys/server.crt", "--ssl_key=keys/server.key", - ], False) + ], False, sessions=autopush.tests.boto_sessions) def test_bad_senderidlist(self): returncode = endpoint_main([ @@ -306,18 +318,18 @@ def test_client_certs(self): "--ssl_cert=keys/server.crt", "--ssl_key=keys/server.key", '--client_certs={"foo": ["%s"]}' % cert - ], False) + ], False, sessions=autopush.tests.boto_sessions) assert not returncode def test_proxy_protocol_port(self): endpoint_main([ "--proxy_protocol_port=8081", - ], False) + ], False, sessions=autopush.tests.boto_sessions) def test_memusage(self): endpoint_main([ "--memusage_port=8083", - ], False) + ], False, sessions=autopush.tests.boto_sessions) def test_client_certs_parse(self): conf = AutopushConfig.from_argparse(self.TestArg) @@ -344,7 +356,8 @@ def test_bad_client_certs(self): @patch('hyper.tls', spec=hyper.tls) 
def test_conf(self, *args): conf = AutopushConfig.from_argparse(self.TestArg) - app = EndpointApplication(conf) + app = EndpointApplication(conf, + sessions=autopush.tests.boto_sessions) # verify that the hostname is what we said. assert conf.hostname == self.TestArg.hostname assert app.routers["gcm"].router_conf['collapsekey'] == "collapse" @@ -375,7 +388,7 @@ def test_gcm_start(self): endpoint_main([ "--gcm_enabled", """--senderid_list={"123":{"auth":"abcd"}}""", - ], False) + ], False, sessions=autopush.tests.boto_sessions) @patch("requests.get") def test_aws_ami_id(self, request_mock): diff --git a/autopush/tests/test_router.py b/autopush/tests/test_router.py index 9e721585..cf25aa23 100644 --- a/autopush/tests/test_router.py +++ b/autopush/tests/test_router.py @@ -7,9 +7,8 @@ import requests import ssl -from autopush.utils import WebPushNotification -from mock import Mock, PropertyMock, patch import pytest +from mock import Mock, PropertyMock, patch from twisted.trial import unittest from twisted.internet.error import ConnectionRefusedError from twisted.internet.defer import inlineCallbacks @@ -35,6 +34,7 @@ from autopush.router.interface import RouterResponse, IRouter from autopush.tests import MockAssist from autopush.tests.support import test_db +from autopush.utils import WebPushNotification class RouterInterfaceTestCase(TestCase): @@ -999,7 +999,7 @@ def setUp(self): mock_result.not_registered = dict() mock_result.retry_after = 1000 self.router_mock = db.router - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) self.conf = conf def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): @@ -1009,6 +1009,7 @@ def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): side_effect=MockAssist([202, 200])) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = 
Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1033,6 +1034,7 @@ def test_route_failure(self): self.agent_mock.request = Mock(side_effect=ConnectionRefusedError) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1070,6 +1072,7 @@ def test_route_to_busy_node_with_ttl_zero(self): side_effect=MockAssist([202, 200])) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data @@ -1100,9 +1103,8 @@ def throw(): self.agent_mock.request.return_value = response_mock = Mock() response_mock.code = 202 - self.message_mock.store_message.side_effect = MockAssist( - [throw] - ) + self.message_mock.store_message.side_effect = MockAssist([throw]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) @@ -1123,6 +1125,7 @@ def throw(): raise JSONResponseError(500, "Whoops") self.message_mock.store_message.return_value = True + self.db.message_table = Mock(return_value=self.message_mock) self.router_mock.get_uaid.side_effect = MockAssist( [throw] ) @@ -1144,6 +1147,7 @@ def throw(): raise ItemNotFound() self.message_mock.store_message.return_value = True + self.db.message_table = Mock(return_value=self.message_mock) self.router_mock.get_uaid.side_effect = 
MockAssist( [throw] ) @@ -1160,9 +1164,8 @@ def verify_deliver(status): def test_route_lookup_uaid_no_nodeid(self): self.message_mock.store_message.return_value = True - self.router_mock.get_uaid.return_value = dict( - - ) + self.db.message_table = Mock(return_value=self.message_mock) + self.router_mock.get_uaid.return_value = dict() router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) @@ -1179,6 +1182,7 @@ def test_route_and_clear_failure(self): self.agent_mock.request = Mock(side_effect=ConnectionRefusedError) self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) + self.db.message_table = Mock(return_value=self.message_mock) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data diff --git a/autopush/tests/test_web_base.py b/autopush/tests/test_web_base.py index dc56255d..fd1d2f71 100644 --- a/autopush/tests/test_web_base.py +++ b/autopush/tests/test_web_base.py @@ -195,8 +195,7 @@ def test_response_err(self): def test_overload_err(self): try: - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': { 'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' @@ -208,8 +207,7 @@ def test_overload_err(self): def test_client_err(self): try: - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': { 'Code': 'Flibbertygidgit'}}, 'mock_update_item' diff --git a/autopush/tests/test_web_webpush.py b/autopush/tests/test_web_webpush.py index ca6327f5..aade1b7a 100644 --- a/autopush/tests/test_web_webpush.py +++ b/autopush/tests/test_web_webpush.py @@ -20,6 +20,7 @@ class TestWebpushHandler(unittest.TestCase): def setUp(self): + import autopush from autopush.web.webpush import WebPushHandler self.conf = conf = AutopushConfig( @@ -30,11 +31,13 @@ def 
setUp(self): self.fernet_mock = conf.fernet = Mock(spec=Fernet) self.db = db = test_db() - self.message_mock = db.message = Mock(spec=Message) + self.message_mock = db._message = Mock(spec=Message) + self.db.message_table = Mock(return_value=self.message_mock) self.message_mock.all_channels.return_value = (True, [dummy_chid]) app = EndpointHTTPFactory.for_handler(WebPushHandler, conf, db=db) self.wp_router_mock = app.routers["webpush"] = Mock(spec=IRouter) + self.db.router = Mock(spec=autopush.db.Router) self.client = Client(app) def url(self, **kwargs): @@ -48,21 +51,24 @@ def test_router_needs_update(self): public_key="asdfasdf", )) self.fernet_mock.decrypt.return_value = dummy_token - self.db.router.get_uaid.return_value = dict( + self.db.router.get_uaid = Mock(return_value=dict( router_type="webpush", router_data=dict(), uaid=dummy_uaid, current_month=self.db.current_msg_month, - ) + )) + self.db.router.register_user = Mock(return_value=False) self.wp_router_mock.route_notification.return_value = RouterResponse( status_code=503, router_data=dict(token="new_connect"), ) + self.db.message_tables.append(self.db.current_msg_month) resp = yield self.client.post( self.url(api_ver="v1", token=dummy_token), ) - assert resp.get_status() == 503 + rstatus = resp.get_status() + assert rstatus == 503 ru = self.db.router.register_user assert ru.called assert 'webpush' == ru.call_args[0][0].get('router_type') @@ -85,6 +91,7 @@ def test_router_returns_data_without_detail(self): status_code=503, router_data=dict(), ) + self.db.message_tables.append(self.db.current_msg_month) resp = yield self.client.post( self.url(api_ver="v1", token=dummy_token), diff --git a/autopush/tests/test_webpush_server.py b/autopush/tests/test_webpush_server.py index 2805ff37..057c2ea1 100644 --- a/autopush/tests/test_webpush_server.py +++ b/autopush/tests/test_webpush_server.py @@ -15,6 +15,7 @@ DatabaseManager, generate_last_connect, make_rotating_tablename, + Message, ) from autopush.metrics import 
SinkMetrics from autopush.config import AutopushConfig @@ -35,6 +36,7 @@ Unregister, WebPushMessage, ) +import autopush.tests class AutopushCall(object): @@ -170,14 +172,18 @@ def setUp(self): begin_or_register(self.logs) self.addCleanup(globalLogPublisher.removeObserver, self.logs) - self.db = db = DatabaseManager.from_config(self.conf) + self.db = db = DatabaseManager.from_config( + self.conf, + boto_sessions=autopush.tests.boto_sessions) self.metrics = db.metrics = Mock(spec=SinkMetrics) db.setup_tables() def _store_messages(self, uaid, topic=False, num=5): try: item = self.db.router.get_uaid(uaid.hex) - message_table = self.db.message_tables[item["current_month"]] + message_table = Message( + item["current_month"], + boto_sessions=autopush.tests.boto_sessions) except ItemNotFound: message_table = self.db.message messages = [WebPushNotificationFactory(uaid=uaid) @@ -441,7 +447,9 @@ def test_migrate_user(self): # Check that it's there item = self.db.router.get_uaid(uaid) - _, channels = self.db.message_tables[last_month].all_channels(uaid) + _, channels = Message( + last_month, + boto_sessions=self.db.sessions).all_channels(uaid) assert item["current_month"] != self.db.current_msg_month assert item is not None assert len(channels) == 3 @@ -504,16 +512,18 @@ def test_register_bad_chid_nodash(self): self._test_invalid(uuid4().hex) def test_register_over_provisioning(self): + import autopush def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + from botocore.exceptions import ClientError + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - self.db.message.table.update_item = Mock( - side_effect=raise_condition) + mock_table = Mock(spec=autopush.db.Message) + mock_table.register_channel = Mock(side_effect=raise_condition) + self.db.message_table = Mock(return_value=mock_table) self._test_invalid(str(uuid4()), "overloaded", 503) diff --git 
a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index 9f973dad..ab37350c 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -8,6 +8,7 @@ import twisted.internet.base from autobahn.twisted.util import sleep from autobahn.websocket.protocol import ConnectionRequest +from botocore.exceptions import ClientError from mock import Mock, patch import pytest from twisted.internet import reactor @@ -25,7 +26,6 @@ from autopush.db import DatabaseManager from autopush.http import InternalRouterHTTPFactory from autopush.metrics import SinkMetrics -from autopush.tests import MockAssist from autopush.utils import WebPushNotification from autopush.tests.client import Client from autopush.tests.test_db import make_webpush_notification @@ -37,6 +37,7 @@ WebSocketServerProtocol, ) from autopush.utils import base64url_encode, ms_time +import autopush.tests dummy_version = (u'gAAAAABX_pXhN22H-hvscOHsMulKvtC0hKJimrZivbgQPFB3sQAtOPmb' @@ -101,7 +102,10 @@ def setUp(self): statsd_host=None, env="test", ) - db = DatabaseManager.from_config(conf) + db = DatabaseManager.from_config( + conf, + boto_sessions=autopush.tests.boto_sessions + ) self.metrics = db.metrics = Mock(spec=SinkMetrics) db.setup_tables() @@ -381,7 +385,9 @@ def test_close_with_delivery_cleanup(self): self.proto.ps.direct_updates[chid] = [notif] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -400,7 +406,9 @@ def test_close_with_delivery_cleanup_using_webpush(self): self.proto.ps.direct_updates[dummy_chid_str] = [dummy_notif()] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = 
Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -421,7 +429,9 @@ def test_close_with_delivery_cleanup_and_get_no_result(self): self.proto.ps.direct_updates[chid] = [notif] # Apply some mocks - self.proto.db.message.store_message = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.store_message = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = False self.metrics.reset_mock() @@ -451,7 +461,8 @@ def test_hello_old(self): "current_month": msg_date, } router = self.proto.db.router - router.table.put_item( + session = self.proto.db.sessions.fetch() + router.table(session).put_item( Item=dict( uaid=orig_uaid, connected_at=ms_time(), @@ -459,8 +470,9 @@ def test_hello_old(self): router_type="webpush" ) ) + self.proto.db.sessions.release(session) - def fake_msg(data): + def fake_msg(data, **kwargs): return (True, msg_data) mock_msg = Mock(wraps=db.Message) @@ -473,12 +485,10 @@ def fake_msg(data): # notifications are irrelevant for this test. 
self.proto.process_notifications = Mock() # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - for k in mt.keys(): - del(mt[k]) - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', 'message_2016_3' + ] + self.proto.ps.db.message_table = Mock(return_value=mock_msg) with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: patched.today.return_value = target_day @@ -500,12 +510,14 @@ def fake_msg(data): def test_hello_tomorrow(self): orig_uaid = "deadbeef00000000abad1dea00000000" router = self.proto.db.router + session = self.proto.db.sessions.fetch() router.register_user(dict( uaid=orig_uaid, connected_at=ms_time(), current_month="message_2016_3", router_type="webpush", )) + self.proto.db.sessions.release(session) # router.register_user returns (registered, previous target_day = datetime.date(2016, 2, 29) @@ -521,7 +533,7 @@ def test_hello_tomorrow(self): "current_month": msg_date, } - def fake_msg(data): + def fake_msg(data, **kwargs): return (True, msg_data) mock_msg = Mock(wraps=db.Message) @@ -530,12 +542,10 @@ def fake_msg(data): mock_msg.all_channels.return_value = (None, []) self.proto.db.router.register_user = fake_msg # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - for k in mt.keys(): - del(mt[k]) - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg + self.proto.db.message_table = Mock(return_value=mock_msg) + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', 'message_2016_3' + ] with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: patched.today.return_value = target_day @@ -559,10 +569,11 @@ def fake_msg(data): def test_hello_tomorrow_provision_error(self): orig_uaid = "deadbeef00000000abad1dea00000000" router = self.proto.db.router + 
current_month = "message_2016_3" router.register_user(dict( uaid=orig_uaid, connected_at=ms_time(), - current_month="message_2016_3", + current_month=current_month, router_type="webpush", )) @@ -580,36 +591,42 @@ def test_hello_tomorrow_provision_error(self): "current_month": msg_date, } - def fake_msg(data): - return (True, msg_data) - mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = "01;", [] mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) - self.proto.db.router.register_user = fake_msg + self.proto.db.router.register_user = Mock( + return_value=(True, msg_data) + ) # massage message_tables to include our fake range - mt = self.proto.ps.db.message_tables - mt.clear() - mt['message_2016_1'] = mock_msg - mt['message_2016_2'] = mock_msg - mt['message_2016_3'] = mock_msg - + self.proto.ps.db.message_tables = [ + 'message_2016_1', 'message_2016_2', current_month + ] + self.proto.db.message_table = Mock(return_value=mock_msg) patch_range = patch("autopush.websocket.randrange") mock_patch = patch_range.start() mock_patch.return_value = 1 def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - self.proto.db.router.update_message_month = MockAssist([ - raise_condition, - Mock(), - ]) + self.proto.db.register_user = Mock(return_value=(False, {})) + mock_router = Mock(spec=db.Router) + mock_router.register_user = Mock(return_value=(True, msg_data)) + mock_router.update_message_month = Mock(side_effect=raise_condition) + self.proto.db.router = mock_router + self.proto.db.router.get_uaid = Mock(return_value={ + "router_type": "webpush", + "connected_at": int(msg_day.strftime("%s")), + "current_month": current_month, + "last_connect": int(msg_day.strftime("%s")), + "record_version": 1, + }) + self.proto.db.current_msg_month = 
current_month + self.proto.ps.message_month = current_month with patch.object(datetime, 'date', Mock(wraps=datetime.date)) as patched: @@ -711,7 +728,9 @@ def test_hello_failure(self): self._connect() # Fail out the register_user call router = self.proto.db.router - router.table.update_item = Mock(side_effect=KeyError) + session = self.proto.db.sessions.fetch() + router.table(session).update_item = Mock(side_effect=KeyError) + self.proto.db.sessions.release(session) self._send_message(dict(messageType="hello", channelIDs=[], use_webpush=True, stop=1)) @@ -727,15 +746,15 @@ def test_hello_provisioned_during_check(self): # Fail out the register_user call def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) router = self.proto.db.router - router.table.update_item = Mock(side_effect=raise_condition) - + mock_table = Mock() + mock_table.update_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mock_table) self._send_message(dict(messageType="hello", use_webpush=True, channelIDs=[])) msg = yield self.get_response() @@ -911,10 +930,12 @@ def test_register_webpush(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.register_channel = Mock() + self.proto.db.message_table = Mock(return_value=msg_mock) yield self.proto.process_register(dict(channelID=chid)) - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -922,7 +943,8 @@ def test_register_webpush_with_key(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = Mock() + msg_mock = Mock(spec=db.Message) + 
self.proto.db.message_table = Mock(return_value=msg_mock) test_key = "SomeRandomCryptoKeyString" test_sha = sha256(test_key).hexdigest() test_endpoint = ('http://localhost/wpush/v2/' + @@ -942,7 +964,7 @@ def echo(string): ) assert test_endpoint == self.proto.sendJSON.call_args[0][0][ 'pushEndpoint'] - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -1046,11 +1068,12 @@ def test_register_over_provisioning(self): self._connect() chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.db.message.register_channel = register = Mock() + msg_mock = Mock(spec=db.Message) + msg_mock.register_channel = register = Mock() + self.proto.ps.db.message_table = Mock(return_value=msg_mock) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) @@ -1058,7 +1081,7 @@ def raise_condition(*args, **kwargs): register.side_effect = raise_condition yield self.proto.process_register(dict(channelID=chid)) - assert self.proto.db.message.register_channel.called + assert msg_mock.register_channel.called assert self.send_mock.called args, _ = self.send_mock.call_args msg = json.loads(args[0]) @@ -1305,9 +1328,11 @@ def test_process_notifications(self): self.proto.ps.uaid = uuid.uuid4().hex # Swap out fetch_notifications - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( return_value=(None, []) ) + self.proto.ps.db.message_table = Mock(return_value=msg_mock) self.proto.process_notifications() @@ -1337,13 +1362,12 @@ def throw(*args, **kwargs): twisted.internet.base.DelayedCall.debug = True self._connect() - self.proto.db.message.fetch_messages = Mock( - return_value=(None, []) - ) - self.proto.db.message.fetch_messages = Mock( 
+ msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( side_effect=throw) - self.proto.db.message.fetch_timestamp_messages = Mock( + msg_mock.fetch_timestamp_messages = Mock( side_effect=throw) + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.ps.uaid = uuid.uuid4().hex self.proto.process_notifications() @@ -1367,21 +1391,19 @@ def wait(result): def test_process_notifications_provision_err(self): def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) twisted.internet.base.DelayedCall.debug = True self._connect() - self.proto.db.message.fetch_messages = Mock( - return_value=(None, []) - ) - self.proto.db.message.fetch_messages = Mock( + msg_mock = Mock(spec=db.Message) + msg_mock.fetch_messages = Mock( side_effect=raise_condition) - self.proto.db.message.fetch_timestamp_messages = Mock( + msg_mock.fetch_timestamp_messages = Mock( side_effect=raise_condition) + self.proto.db.message_table = Mock(return_value=msg_mock) self.proto.deferToLater = Mock() self.proto.ps.uaid = uuid.uuid4().hex @@ -1488,7 +1510,9 @@ def test_notif_finished_with_too_many_messages(self): self.proto.ps.uaid = uuid.uuid4().hex self.proto.ps._check_notifications = True self.proto.db.router.drop_user = Mock() - self.proto.ps.message.fetch_messages = Mock() + msg_mock = Mock() + msg_mock.fetch_messages = Mock() + self.proto.ps.db.message_table = Mock(return_value=msg_mock) notif = make_webpush_notification( self.proto.ps.uaid, @@ -1496,7 +1520,7 @@ def test_notif_finished_with_too_many_messages(self): ttl=500 ) self.proto.ps.updates_sent = defaultdict(lambda: []) - self.proto.ps.message.fetch_messages.return_value = ( + msg_mock.fetch_messages.return_value = ( None, [notif, notif, notif] ) diff --git a/autopush/web/health.py b/autopush/web/health.py index d28265b2..5a437297 100644 --- 
a/autopush/web/health.py +++ b/autopush/web/health.py @@ -32,15 +32,18 @@ def get(self): clients=len(getattr(self.application, 'clients', ())) ) + session = self.db.sessions.fetch() + dl = DeferredList([ - self._check_table(self.db.router.table), - self._check_table(self.db.message.table, "storage") + self._check_table(self.db.router.table(session), session), + self._check_table(self.db.message.table(session), session, + "storage") ]) - dl.addBoth(self._finish_response) + dl.addBoth(self._finish_response, session) - def _check_table(self, table, name_over=None): + def _check_table(self, table, session, name_over=None): """Checks the tables known about in DynamoDB""" - d = deferToThread(table_exists, table.table_name, self.db.client) + d = deferToThread(table_exists, table.table_name, boto_session=session) d.addCallback(self._check_success, name_over or table.table_name) d.addErrback(self._check_error, name_over or table.table_name) return d @@ -65,7 +68,7 @@ def _check_error(self, failure, name): else: cause["error"] = "Internal error" - def _finish_response(self, results): + def _finish_response(self, results, session): """Returns whether the check succeeded or not""" if self._healthy: self._health_checks["status"] = "OK" @@ -74,6 +77,7 @@ def _finish_response(self, results): self._health_checks["status"] = "NOT OK" self.write(self._health_checks) + self.db.sessions.release(session) self.finish() diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 2d6002fa..954198b8 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -149,7 +149,8 @@ def _validate_webpush(self, d, result): db.router.drop_user(uaid) raise InvalidRequest("No such subscription", status_code=410, errno=106) - exists, chans = db.message_tables[month_table].all_channels(uaid=uaid) + msg = db.message_table(month_table) + exists, chans = msg.all_channels(uaid=uaid) if (not exists or channel_id.lower() not in map(lambda x: normalize_id(x), chans)): diff --git 
a/autopush/webpush_server.py b/autopush/webpush_server.py index 1c8e6d04..130b7781 100644 --- a/autopush/webpush_server.py +++ b/autopush/webpush_server.py @@ -25,6 +25,8 @@ has_connected_this_month, hasher, generate_last_connect, + Message, + BotoSessions, ) from autopush.config import AutopushConfig # noqa @@ -356,7 +358,12 @@ def process(self, hello): # Save the UAID as register_user removes it uaid = user_item["uaid"] # type: str - success, _ = self.db.router.register_user(user_item) + session = self.db.sessions.fetch() + try: + success, _ = self.db.router.register_user(user_item, + boto_session=session) + finally: + self.db.sessions.release(session) flags["connected_at"] = hello.connected_at if not success: # User has already connected more recently elsewhere @@ -374,10 +381,13 @@ def lookup_user(self, hello): rotate_message_table=False, ) uaid = hello.uaid.hex + session = self.db.sessions.fetch() try: - record = self.db.router.get_uaid(uaid) + record = self.db.router.get_uaid(uaid, session) except ItemNotFound: return None, flags + finally: + self.db.sessions.release(session) # All records must have a router_type and connected_at, in some odd # cases a record exists for some users without it @@ -437,7 +447,11 @@ def drop_user(self, uaid, uaid_record, code): uaid_record=repr(uaid_record) ) self.metrics.increment('ua.expiration', tags=['code:{}'.format(code)]) - self.db.router.drop_user(uaid) + session = self.db.sessions.fetch() + try: + self.db.router.drop_user(uaid, boto_session=session) + finally: + self.db.sessions.release(session) class CheckStorageCommand(ProcessorCommand): @@ -445,20 +459,27 @@ def process(self, command): # type: (CheckStorage) -> CheckStorageResponse # First, determine if there's any messages to retrieve - timestamp, messages, include_topic = self._check_storage(command) + session = self.db.sessions.fetch() + try: + timestamp, messages, include_topic = self._check_storage(command, + session) + finally: + 
self.db.sessions.release(session) return CheckStorageResponse( timestamp=timestamp, messages=messages, include_topic=include_topic, ) - def _check_storage(self, command): + def _check_storage(self, command, session): timestamp = None messages = [] - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_sessions=self.db.sessions) if command.include_topic: timestamp, messages = message.fetch_messages( uaid=command.uaid, limit=11, + boto_session=session ) # If we have topic messages, return them immediately @@ -475,6 +496,7 @@ def _check_storage(self, command): timestamp, messages = message.fetch_timestamp_messages( uaid=command.uaid, timestamp=command.timestamp, + boto_session=session ) messages = [WebPushMessage.from_WebPushNotification(m) for m in messages] @@ -484,7 +506,8 @@ def _check_storage(self, command): class IncrementStorageCommand(ProcessorCommand): def process(self, command): # type: (IncStoragePosition) -> IncStoragePositionResponse - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_sessions=self.db.sessions) message.update_last_message_read(command.uaid, command.timestamp) return IncStoragePositionResponse() @@ -493,7 +516,8 @@ class DeleteMessageCommand(ProcessorCommand): def process(self, command): # type: (DeleteMessage) -> DeleteMessageResponse notif = command.message.to_WebPushNotification() - message = self.db.message_tables[command.message_month] + message = Message(command.message_month, + boto_sessions=self.db.sessions) message.delete_message(notif) return DeleteMessageResponse() @@ -501,7 +525,11 @@ def process(self, command): class DropUserCommand(ProcessorCommand): def process(self, command): # type: (DropUser) -> DropUserResponse - self.db.router.drop_user(command.uaid.hex) + session = self.db.sessions.fetch() + try: + self.db.router.drop_user(command.uaid.hex) + finally: + self.db.sessions.release(session) return DropUserResponse() 
@@ -509,30 +537,45 @@ class MigrateUserCommand(ProcessorCommand): def process(self, command): # type: (MigrateUser) -> MigrateUserResponse # Get the current channels for this month - message = self.db.message_tables[command.message_month] - _, channels = message.all_channels(command.uaid.hex) - - # Get the current message month - cur_month = self.db.current_msg_month - if channels: - # Save the current channels into this months message table - msg_table = self.db.message_tables[cur_month] - msg_table.save_channels(command.uaid.hex, channels) - - # Finally, update the route message month - self.db.router.update_message_month(command.uaid.hex, cur_month) + message = Message(command.message_month, + boto_sessions=self.db.sessions) + session = self.db.sessions.fetch() + try: + _, channels = message.all_channels(command.uaid.hex, + boto_session=session) + + # Get the current message month + cur_month = self.db.current_msg_month + if channels: + # Save the current channels into this months message table + msg_table = Message(cur_month) + msg_table.save_channels(command.uaid.hex, + channels, + boto_session=session) + + # Finally, update the route message month + self.db.router.update_message_month(command.uaid.hex, + cur_month, + boto_session=session) + finally: + self.db.sessions.release(session) return MigrateUserResponse(message_month=cur_month) class StoreMessagesUserCommand(ProcessorCommand): def process(self, command): # type: (StoreMessages) -> StoreMessagesResponse - message = self.db.message_tables[command.message_month] - for m in command.messages: - if "topic" not in m: - m["topic"] = None - notif = WebPushMessage(**m).to_WebPushNotification() - message.store_message(notif) + session = self.db.sessions.fetch() + try: + message = Message(command.message_month, + boto_sessions=self.db.sessions) + for m in command.messages: + if "topic" not in m: + m["topic"] = None + notif = WebPushMessage(**m).to_WebPushNotification() + message.store_message(notif, 
boto_session=session) + finally: + self.db.sessions.release(session) return StoreMessagesResponse() @@ -581,15 +624,19 @@ def process(self, command): command.channel_id, command.key ) - message = self.db.message_tables[command.message_month] + session = self.db.sessions.fetch() + message = self.db.message_table(command.message_month) try: - message.register_channel(command.uaid.hex, command.channel_id) + message.register_channel(command.uaid.hex, + command.channel_id, + boto_session=session) except ClientError as ex: if (ex.response['Error']['Code'] == "ProvisionedThroughputExceededException"): return RegisterErrorResponse(error_msg="overloaded", status=503) - + finally: + self.db.sessions.release(session) self.metrics.increment('ua.command.register') log.info( "Register", @@ -630,9 +677,14 @@ def process(self, if not valid: return UnregisterErrorResponse(error_msg=msg) - message = self.db.message_tables[command.message_month] + session = self.db.sessions.fetch() + message = Message(command.message_month, + boto_sessions=self.db.sessions) # TODO: JSONResponseError not handled (no force_retry) - message.unregister_channel(command.uaid.hex, command.channel_id) + try: + message.unregister_channel(command.uaid.hex, command.channel_id) + finally: + self.db.sessions.release(session) # TODO: Clear out any existing tracked messages for this # channel diff --git a/autopush/websocket.py b/autopush/websocket.py index a3de965a..15dbb0b2 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -249,7 +249,7 @@ def __attrs_post_init__(self): def message(self): # type: () -> Message """Property to access the currently used message table""" - return self.db.message_tables[self.message_month] + return Message(self.message_month) @property def user_agent(self): @@ -596,7 +596,8 @@ def cleanUp(self, wasClean, code, reason): def _save_webpush_notif(self, notif): """Save a direct_update webpush style notification""" - return deferToThread(self.ps.message.store_message, + 
message = self.db.message_table(self.ps.message_month) + return deferToThread(message.store_message, notif).addErrback(self.log_failure) def _lookup_node(self, results): @@ -648,7 +649,7 @@ def returnError(self, messageType, reason, statusCode, close=True, if close: self.sendClose() - def err_overload(self, failure, message_type, disconnect=True): + def error_overload(self, failure, message_type, disconnect=True): """Handle database overloads and errors If ``disconnect`` is False, the an overload error is returned and the @@ -668,7 +669,7 @@ def err_overload(self, failure, message_type, disconnect=True): if disconnect: self.transport.pauseProducing() d = self.deferToLater(self.randrange(4, 9), - self.err_finish_overload, message_type) + self.error_finish_overload, message_type) d.addErrback(self.trap_cancel) else: if (failure.value.response["Error"]["Code"] != @@ -678,7 +679,7 @@ def err_overload(self, failure, message_type, disconnect=True): "status": 503} self.sendJSON(send) - def err_finish_overload(self, message_type): + def error_finish_overload(self, message_type): """Close the connection down and resume consuming input after the random interval from a db overload""" # Resume producing so we can finish the shutdown @@ -716,8 +717,8 @@ def process_hello(self, data): d = self.deferToThread(self._register_user, existing_user) d.addCallback(self._check_other_nodes) d.addErrback(self.trap_cancel) - d.addErrback(self.err_overload, "hello") - d.addErrback(self.err_hello) + d.addErrback(self.error_overload, "hello") + d.addErrback(self.error_hello) self.ps._register = d return d @@ -745,7 +746,12 @@ def _register_user(self, existing_user=True): ) user_item["current_month"] = self.ps.message_month - return self.db.router.register_user(user_item) + session = self.db.sessions.fetch() + try: + return self.db.router.register_user(user_item, + boto_session=session) + finally: + self.db.sessions.release(session) def _verify_user_record(self): """Verify a user record is 
valid @@ -756,57 +762,64 @@ def _verify_user_record(self): :rtype: :class:`~boto.dynamodb2.items.Item` or None """ + session = self.db.sessions.fetch() try: - record = self.db.router.get_uaid(self.ps.uaid) - except ItemNotFound: - return None - - # All records must have a router_type and connected_at, in some odd - # cases a record exists for some users that doesn't - if "router_type" not in record or "connected_at" not in record: - self.log.debug(format="Dropping User", code=104, - uaid_hash=self.ps.uaid_hash, - uaid_record=repr(record)) - tags = ['code:104'] - self.metrics.increment("ua.expiration", tags=tags) - self.force_retry(self.db.router.drop_user, self.ps.uaid) - return None - - # Validate webpush records - # Current month must exist and be a valid prior month - if ("current_month" not in record) or record["current_month"] \ - not in self.db.message_tables: - self.log.debug(format="Dropping User", code=105, - uaid_hash=self.ps.uaid_hash, - uaid_record=repr(record)) - self.force_retry(self.db.router.drop_user, - self.ps.uaid) - tags = ['code:105'] - self.metrics.increment("ua.expiration", tags=tags) - return None - - # Determine if message table rotation is needed - if record["current_month"] != self.ps.message_month: - self.ps.message_month = record["current_month"] - self.ps.rotate_message_table = True - - # Include and update last_connect if needed, otherwise exclude - if has_connected_this_month(record): - del record["last_connect"] - else: - record["last_connect"] = generate_last_connect() + try: + record = self.db.router.get_uaid(self.ps.uaid, + boto_session=session) + except ItemNotFound: + return None + + # All records must have a router_type and connected_at, in some odd + # cases a record exists for some users that doesn't + if "router_type" not in record or "connected_at" not in record: + self.log.debug(format="Dropping User", code=104, + uaid_hash=self.ps.uaid_hash, + uaid_record=repr(record)) + tags = ['code:104'] + 
self.metrics.increment("ua.expiration", tags=tags) + self.force_retry(self.db.router.drop_user, self.ps.uaid, + boto_session=session) + return None + + # Validate webpush records + # Current month must exist and be a valid prior month + if ("current_month" not in record) or record["current_month"] \ + not in self.db.message_tables: + self.log.debug(format="Dropping User", code=105, + uaid_hash=self.ps.uaid_hash, + uaid_record=repr(record)) + self.force_retry(self.db.router.drop_user, + self.ps.uaid, + boto_session=session) + tags = ['code:105'] + self.metrics.increment("ua.expiration", tags=tags) + return None + + # Determine if message table rotation is needed + if record["current_month"] != self.ps.message_month: + self.ps.message_month = record["current_month"] + self.ps.rotate_message_table = True + + # Include and update last_connect if needed, otherwise exclude + if has_connected_this_month(record): + del record["last_connect"] + else: + record["last_connect"] = generate_last_connect() - # Determine if this is missing a record version - if ("record_version" not in record or - int(record["record_version"]) < USER_RECORD_VERSION): - self.ps.reset_uaid = True + # Determine if this is missing a record version + if ("record_version" not in record or + int(record["record_version"]) < USER_RECORD_VERSION): + self.ps.reset_uaid = True - # Update the node_id, connected_at for this node/connected_at - record["node_id"] = self.conf.router_url - record["connected_at"] = self.ps.connected_at - return record + # Update the node_id, connected_at for this node/connected_at + record["node_id"] = self.conf.router_url + record["connected_at"] = self.ps.connected_at + return record + finally: + self.db.sessions.release(session) - def err_hello(self, failure): + def error_hello(self, failure): """errBack for hello failures""" self.transport.resumeProducing() self.log_failure(failure) @@ -903,11 +916,14 @@ def process_notifications(self): def webpush_fetch(self): """Helper to 
return an appropriate function to fetch messages""" + message = self.db.message_table(self.ps.message_month) if self.ps.scan_timestamps: - return partial(self.ps.message.fetch_timestamp_messages, - self.ps.uaid_obj, self.ps.current_timestamp) + return partial(message.fetch_timestamp_messages, + self.ps.uaid_obj, + self.ps.current_timestamp) else: - return partial(self.ps.message.fetch_messages, self.ps.uaid_obj) + return partial(message.fetch_messages, + self.ps.uaid_obj) def error_notifications(self, fail): """errBack for notification check failing""" @@ -931,7 +947,13 @@ def error_notification_overload(self, fail): def error_message_overload(self, fail): """errBack for handling excessive messages per UAID""" fail.trap(MessageOverloadException) - self.force_retry(self.db.router.drop_user, self.ps.uaid) + session = self.db.sessions.fetch() + try: + self.force_retry(self.db.router.drop_user, + self.ps.uaid, + boto_session=session) + finally: + self.db.sessions.release(session) self.sendClose() def finish_notifications(self, notifs): @@ -978,7 +1000,13 @@ def finish_webpush_notifications(self, result): # Told to reset the user? 
if self.ps.reset_uaid: - self.force_retry(self.db.router.drop_user, self.ps.uaid) + session = self.db.sessions.fetch() + try: + self.force_retry(self.db.router.drop_user, + self.ps.uaid, + boto_session=session) + finally: + self.db.sessions.release(session) self.sendClose() # Not told to check for notifications, do we need to now rotate @@ -990,39 +1018,50 @@ def finish_webpush_notifications(self, result): # Send out all the notifications now = int(time.time()) messages_sent = False - for notif in notifs: - self.ps.stats.stored_retrieved += 1 - # If the TTL is too old, don't deliver and fire a delete off - if notif.expired(at_time=now): - if not notif.sortkey_timestamp: - # Delete non-timestamped messages - self.force_retry(self.ps.message.delete_message, notif) - - # nocover here as coverage gets confused on the line below - # for unknown reasons - continue # pragma: nocover - - self.ps.updates_sent[str(notif.channel_id)].append(notif) - msg = notif.websocket_format() - messages_sent = True - self.sent_notification_count += 1 - if self.sent_notification_count > self.conf.msg_limit: - raise MessageOverloadException() - if notif.topic: - self.metrics.increment("ua.notification.topic") - self.metrics.increment('ua.message_data', len(msg.get('data', '')), - tags=make_tags(source=notif.source)) - self.sendJSON(msg) - - # Did we send any messages? 
- if messages_sent: - return + message = self.db.message_table(self.ps.message_month) + session = self.db.sessions.fetch() + try: + for notif in notifs: + self.ps.stats.stored_retrieved += 1 + # If the TTL is too old, don't deliver and fire a delete off + if notif.expired(at_time=now): + if not notif.sortkey_timestamp: + # Delete non-timestamped messages + self.force_retry(message.delete_message, + notif, + boto_session=session) + + # nocover here as coverage gets confused on the line below + # for unknown reasons + continue # pragma: nocover + + self.ps.updates_sent[str(notif.channel_id)].append(notif) + msg = notif.websocket_format() + messages_sent = True + self.sent_notification_count += 1 + if self.sent_notification_count > self.conf.msg_limit: + raise MessageOverloadException() + if notif.topic: + self.metrics.increment("ua.notification.topic") + self.metrics.increment('ua.message_data', + len(msg.get('data', '')), + tags=make_tags(source=notif.source)) + self.sendJSON(msg) + + # Did we send any messages? 
+ if messages_sent: + return - # No messages sent, update the record if needed - if self.ps.current_timestamp: - self.force_retry( - self.ps.message.update_last_message_read, - self.ps.uaid_obj, self.ps.current_timestamp) + # No messages sent, update the record if needed + if self.ps.current_timestamp: + self.force_retry( + message.update_last_message_read, + self.ps.uaid_obj, + self.ps.current_timestamp, + boto_session=session + ) + finally: + self.db.sessions.release(session) # Schedule a new process check self.check_missed_notifications(None) @@ -1047,17 +1086,23 @@ def _monthly_transition(self): """ # Get the current channels for this month - _, channels = self.ps.message.all_channels(self.ps.uaid) + message = self.db.message_table(self.ps.message_month) + _, channels = message.all_channels(self.ps.uaid) # Get the current message month cur_month = self.db.current_msg_month if channels: # Save the current channels into this months message table - msg_table = self.db.message_tables[cur_month] + msg_table = self.db.message_table(cur_month) msg_table.save_channels(self.ps.uaid, channels) # Finally, update the route message month - self.db.router.update_message_month(self.ps.uaid, cur_month) + session = self.db.sessions.fetch() + try: + self.db.router.update_message_month(self.ps.uaid, cur_month, + boto_session=session) + finally: + self.db.sessions.release(session) def _finish_monthly_transition(self, result): """Mark the client as successfully transitioned and resume""" @@ -1137,12 +1182,13 @@ def error_register(self, fail): def finish_register(self, endpoint, chid): """callback for successful endpoint creation, sends register reply""" - d = self.deferToThread(self.ps.message.register_channel, self.ps.uaid, + message = self.db.message_table(self.ps.message_month) + d = self.deferToThread(message.register_channel, self.ps.uaid, chid) d.addCallback(self.send_register_finish, endpoint, chid) # Note: No trap_cancel needed here since the deferred here is # returned to 
process_register which will trap it - d.addErrback(self.err_overload, "register", disconnect=False) + d.addErrback(self.error_overload, "register", disconnect=False) return d def send_register_finish(self, result, endpoint, chid): @@ -1180,7 +1226,8 @@ def process_unregister(self, data): self.ps.updates_sent[chid] = [] # Unregister the channel - self.force_retry(self.ps.message.unregister_channel, self.ps.uaid, + message = self.db.message_table(self.ps.message_month) + self.force_retry(message.unregister_channel, self.ps.uaid, chid) data["status"] = 200 @@ -1238,6 +1285,7 @@ def ver_filter(notif): **self.ps.raw_agent) self.ps.stats.stored_acked += 1 + message = self.db.message_table(self.ps.message_month) if msg.sortkey_timestamp: # Is this the last un-acked message we're waiting for? last_unacked = sum( @@ -1249,7 +1297,7 @@ def ver_filter(notif): # If it's the last message in the batch, or last un-acked # message d = self.force_retry( - self.ps.message.update_last_message_read, + message.update_last_message_read, self.ps.uaid_obj, self.ps.current_timestamp, ) d.addBoth(self._handle_webpush_update_remove, chid, msg) @@ -1260,7 +1308,7 @@ def ver_filter(notif): d = None else: # No sortkey_timestamp, so legacy/topic message, delete - d = self.force_retry(self.ps.message.delete_message, msg) + d = self.force_retry(message.delete_message, msg) # We don't remove the update until we know the delete ran # This is because we don't use range queries on dynamodb and # we need to make sure this notification is deleted from the diff --git a/autopush_rs/Cargo.lock b/autopush_rs/Cargo.lock index 7edf6c92..8f6db862 100644 --- a/autopush_rs/Cargo.lock +++ b/autopush_rs/Cargo.lock @@ -1,4 +1,4 @@ -[root] +[[package]] name = "autopush" version = "0.1.0" dependencies = [ diff --git a/tox.ini b/tox.ini index 5ab422e5..3163daab 100644 --- a/tox.ini +++ b/tox.ini @@ -1,10 +1,11 @@ [tox] -envlist = py27,pypy,flake8,py36-mypy +#envlist = py27,pypy,flake8,py36-mypy +envlist = pypy,flake8 
[testenv] deps = -rtest-requirements.txt usedevelop = True -passenv = SKIP_INTEGRATION AWS_SHARED_CREDENTIALS_FILE BOTO_CONFIG +passenv = SKIP_INTEGRATION AWS_SHARED_CREDENTIALS_FILE BOTO_CONFIG AWS_LOCAL_DYNAMODB commands = nosetests {posargs} autopush install_command = pip install --pre {opts} {packages} @@ -12,7 +13,7 @@ install_command = pip install --pre {opts} {packages} [testenv:pypy] basepython = pypy # avoids pycrypto build issues w/ pypy + libgmp-dev or libmpir-dev -setenv = with_gmp=no +setenv = with_gmp=no AWS_LOCAL_DYNAMODB=http://localhost:8000 [testenv:flake8] commands = flake8 autopush