diff --git a/autopush/db.py b/autopush/db.py index bae7eaea..9557036d 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -50,6 +50,7 @@ ) import boto3 import botocore +from boto3 import Session from boto3.dynamodb.conditions import Key from boto3.exceptions import Boto3Error from botocore.exceptions import ClientError @@ -94,15 +95,15 @@ TRACK_DB_CALLS = False DB_CALLS = [] -# See https://botocore.readthedocs.io/en/stable/reference/config.html for -# additional config options -g_dynamodb = boto3.resource( - 'dynamodb', - config=botocore.config.Config( - region_name=os.getenv("AWS_REGION_NAME", "us-east-1") + +def get_session(): + session = boto3.session.Session() + return session.resource( + 'dynamodb', + config=botocore.config.Config( + region_name=os.getenv("AWS_REGION_NAME", "us-east-1") + ) ) -) -g_client = g_dynamodb.meta.client def get_month(delta=0): @@ -144,20 +145,21 @@ def make_rotating_tablename(prefix, delta=0, date=None): def create_rotating_message_table(prefix="message", delta=0, date=None, read_throughput=5, - write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table + write_throughput=5, + session=get_session()): + # type: (str, int, Optional[datetime.date], int, int, Session) -> Table """Create a new message table for webpush style message storage""" tablename = make_rotating_tablename(prefix, delta, date) try: - table = g_dynamodb.Table(tablename) + table = session.Table(tablename) if table.table_status == 'ACTIVE': # pragma nocover return table except ClientError as ex: if ex.response['Error']['Code'] != 'ResourceNotFoundException': # If we hit this, our boto3 is misconfigured and we need to bail. raise ex # pragma nocover - table = g_dynamodb.create_table( + table = session.create_table( TableName=tablename, KeySchema=[ { @@ -198,24 +200,28 @@ def create_rotating_message_table(prefix="message", delta=0, date=None, return table -def get_rotating_message_table(prefix="message", delta=0, date=None, - message_read_throughput=5, - message_write_throughput=5): - # type: (str, int, Optional[datetime.date], int, int) -> Table +def get_rotating_message_tablename(prefix="message", delta=0, date=None, + message_read_throughput=5, + message_write_throughput=5, + session=get_session()): + # type: (str, int, Optional[datetime.date], int, int, Session) -> str """Gets the message table for the current month.""" tablename = make_rotating_tablename(prefix, delta, date) - if not table_exists(tablename): - return create_rotating_message_table( + if not table_exists(tablename, session=session): + create_rotating_message_table( prefix=prefix, delta=delta, date=date, read_throughput=message_read_throughput, write_throughput=message_write_throughput, + session=session ) + return tablename else: - return g_dynamodb.Table(tablename) + return tablename def create_router_table(tablename="router", read_throughput=5, - write_throughput=5): + write_throughput=5, + session=get_session()): # type: (str, int, int) -> Table """Create a new router table @@ -231,7 +237,7 @@ def create_router_table(tablename="router", read_throughput=5, """ - table = g_dynamodb.create_table( + table = session.create_table( TableName=tablename, KeySchema=[ { @@ -290,20 +296,22 @@ def create_router_table(tablename="router", read_throughput=5, return table -def _drop_table(tablename): +def _drop_table(tablename, session=get_session()): try: - g_client.delete_table(TableName=tablename) + session.meta.client.delete_table(TableName=tablename) except ClientError: # pragma nocover pass -def _make_table(table_func, tablename, read_throughput, write_throughput): +def _make_table(table_func, tablename, read_throughput, write_throughput, + session=get_session()): # type: (Callable[[str, int, int], Table], str, int, int) -> Table """Private common function to make a table with a table func""" - if not table_exists(tablename): - return table_func(tablename, read_throughput, write_throughput) + if not table_exists(tablename, session): + return table_func(tablename, read_throughput, write_throughput, + session) else: - return g_dynamodb.Table(tablename) + return session.Table(tablename) def _expiry(ttl): @@ -311,7 +319,7 @@ def _expiry(ttl): def get_router_table(tablename="router", read_throughput=5, - write_throughput=5): + write_throughput=5, session=get_session()): # type: (str, int, int) -> Table """Get the main router table object @@ -320,7 +328,7 @@ def get_router_table(tablename="router", read_throughput=5, """ return _make_table(create_router_table, tablename, read_throughput, - write_throughput) + write_throughput, session=session) def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): @@ -331,10 +339,12 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): Failure to run correctly will raise an exception. """ + session = get_session() # Verify tables are ready for use if they just got created ready = False while not ready: - tbl_status = [x.table_status() for x in [message, router]] + tbl_status = [x.table_status(session=session) for x in [message, + router]] ready = all([status == "ACTIVE" for status in tbl_status]) if not ready: time.sleep(1) @@ -352,20 +362,20 @@ def preflight_check(message, router, uaid="deadbeef00000000deadbeef00000000"): message_id=message_id, ttl=60, ) - # Store a notification, fetch it, delete it - message.store_message(notif) - assert message.delete_message(notif) + message.store_message(notif, session=session) + assert message.delete_message(notif, session=session) # Store a router entry, fetch it, delete it router.register_user(dict(uaid=uaid.hex, node_id=node_id, connected_at=connected_at, - router_type="webpush")) - item = router.get_uaid(uaid.hex) + router_type="webpush"), + session=session) + item = router.get_uaid(uaid.hex, session=session) assert item.get("node_id") == node_id # Clean up the preflight data. - router.clear_node(item) - router.drop_user(uaid.hex) + router.clear_node(item, session=session) + router.drop_user(uaid.hex, session=session) def track_provisioned(func): @@ -428,8 +438,9 @@ def generate_last_connect_values(date): yield int(val) -def list_tables(client=g_client): +def list_tables(client=None, session=get_session()): """Return a list of the names of all DynamoDB tables.""" + client = session.meta.client start_table = None while True: if start_table: # pragma nocover @@ -443,33 +454,36 @@ def list_tables(client=g_client): break -def table_exists(tablename, client=None): +def table_exists(tablename, session=get_session()): """Determine if the specified Table exists""" - if not client: - client = g_client - return tablename in list_tables(client) + return tablename in list_tables(session) class Message(object): """Create a Message table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics) -> None + def __init__(self, tablename, metrics, max_ttl=MAX_EXPIRY): + # type: (str, IMetrics, int) -> None """Create a new Message object - :param table: :class:`Table` object. + :param tablename: name of the table. :param metrics: Metrics object that implements the :class:`autopush.metrics.IMetrics` interface. + :param session: Session for thread """ - self.table = table + self.tablename = tablename self.metrics = metrics self._max_ttl = max_ttl - def table_status(self): - return self.table.table_status + def table(self, session=get_session()): + return session.Table(self.tablename) + + def table_status(self, session=get_session()): + return self.table(session).table_status @track_provisioned - def register_channel(self, uaid, channel_id, ttl=None): + def register_channel(self, uaid, channel_id, ttl=None, + session=get_session()): # type: (str, str, int) -> bool """Register a channel for a given uaid""" # Generate our update expression @@ -479,7 +493,7 @@ def register_channel(self, uaid, channel_id, ttl=None): ":channel_id": set([normalize_id(channel_id)]), ":expiry": _expiry(ttl) } - self.table.update_item( + self.table(session).update_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -490,14 +504,15 @@ def register_channel(self, uaid, channel_id, ttl=None): return True @track_provisioned - def unregister_channel(self, uaid, channel_id, **kwargs): + def unregister_channel(self, uaid, channel_id, session=get_session(), + **kwargs): # type: (str, str, **str) -> bool """Remove a channel registration for a given uaid""" expr = "DELETE chids :channel_id" chid = normalize_id(channel_id) expr_values = {":channel_id": set([chid])} - response = self.table.update_item( + response = self.table(session).update_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -516,14 +531,14 @@ def unregister_channel(self, uaid, channel_id, **kwargs): return False @track_provisioned - def all_channels(self, uaid): + def all_channels(self, uaid, session=get_session()): # type: (str) -> Tuple[bool, Set[str]] """Retrieve a list of all channels for a given uaid""" # Note: This only returns the chids associated with the UAID. # Functions that call store_message() would be required to # update that list as well using register_channel() - result = self.table.get_item( + result = self.table(session).get_item( Key={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -537,10 +552,10 @@ def all_channels(self, uaid): return True, result['Item'].get("chids", set([])) @track_provisioned - def save_channels(self, uaid, channels): + def save_channels(self, uaid, channels, session=get_session()): # type: (str, Set[str]) -> None """Save out a set of channels""" - self.table.put_item( + self.table(session).put_item( Item={ 'uaid': hasher(uaid), 'chidmessageid': ' ', @@ -550,7 +565,7 @@ def save_channels(self, uaid, channels): ) @track_provisioned - def store_message(self, notification): + def store_message(self, notification, session=get_session()): # type: (WebPushNotification) -> None """Stores a WebPushNotification in the message table""" item = dict( @@ -566,15 +581,15 @@ def store_message(self, notification): ) if notification.data: item['data'] = notification.data - self.table.put_item(Item=item) + self.table(session).put_item(Item=item) @track_provisioned - def delete_message(self, notification): + def delete_message(self, notification, session=get_session()): # type: (WebPushNotification) -> bool """Deletes a specific message""" if notification.update_id: try: - self.table.delete_item( + self.table(session).delete_item( Key={ 'uaid': hasher(notification.uaid.hex), 'chidmessageid': notification.sort_key @@ -588,7 +603,7 @@ def delete_message(self, notification): except ClientError: return False else: - self.table.delete_item( + self.table(session).delete_item( Key={ 'uaid': hasher(notification.uaid.hex), 'chidmessageid': notification.sort_key, @@ -600,6 +615,7 @@ def fetch_messages( self, uaid, # type: uuid.UUID limit=10, # type: int + session=get_session() ): # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] """Fetches messages for a uaid @@ -609,7 +625,7 @@ def fetch_messages( """ # Eagerly fetches all results in the result set. - response = self.table.query( + response = self.table(session).query( KeyConditionExpression=(Key("uaid").eq(hasher(uaid.hex)) & Key('chidmessageid').lt('02')), ConsistentRead=True, @@ -636,6 +652,7 @@ def fetch_timestamp_messages( uaid, # type: uuid.UUID timestamp=None, # type: Optional[Union[int, str]] limit=10, # type: int + session=get_session(), # type: Session ): # type: (...) -> Tuple[Optional[int], List[WebPushNotification]] """Fetches timestamped messages for a uaid @@ -654,7 +671,7 @@ def fetch_timestamp_messages( else: sortkey = "01;" - response = self.table.query( + response = self.table(session).query( KeyConditionExpression=(Key('uaid').eq(hasher(uaid.hex)) & Key('chidmessageid').gt(sortkey)), ConsistentRead=True, @@ -671,13 +688,13 @@ def fetch_timestamp_messages( return last_position, notifs @track_provisioned - def update_last_message_read(self, uaid, timestamp): - # type: (uuid.UUID, int) -> bool + def update_last_message_read(self, uaid, timestamp, session=get_session()): + # type: (uuid.UUID, int, Session) -> bool """Update the last read timestamp for a user""" expr = "SET current_timestamp=:timestamp, expiry=:expiry" expr_values = {":timestamp": timestamp, ":expiry": _expiry(self._max_ttl)} - self.table.update_item( + self.table(session).update_item( Key={ "uaid": hasher(uaid.hex), "chidmessageid": " " @@ -690,8 +707,8 @@ def update_last_message_read(self, uaid, timestamp): class Router(object): """Create a Router table abstraction on top of a DynamoDB Table object""" - def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): - # type: (Table, IMetrics) -> None + def __init__(self, conf, metrics, max_ttl=MAX_EXPIRY): + # type: (dict, IMetrics) -> None """Create a new Router object :param table: :class:`Table` object. @@ -699,14 +716,23 @@ def __init__(self, table, metrics, max_ttl=MAX_EXPIRY): :class:`autopush.metrics.IMetrics` interface. """ - self.table = table + self.conf = conf self.metrics = metrics self._max_ttl = max_ttl + self._session = None + self._cached_table = None + + def table(self, session=get_session()): + if self._cached_table and session == self._session: + return self._cached_table + self._session = session + self._cached_table = get_router_table(session=session, **self.conf) + return self._cached_table - def table_status(self): - return self.table.table_status + def table_status(self, session=get_session()): + return self.table(session).table_status - def get_uaid(self, uaid): + def get_uaid(self, uaid, session=get_session()): # type: (str) -> Item """Get the database record for the UAID @@ -717,7 +743,7 @@ def get_uaid(self, uaid): """ try: - item = self.table.get_item( + item = self.table(session).get_item( Key={ 'uaid': hasher(uaid) }, @@ -741,7 +767,7 @@ def get_uaid(self, uaid): raise ItemNotFound("uaid not found") @track_provisioned - def register_user(self, data): + def register_user(self, data, session=get_session()): # type: (ItemLike) -> Tuple[bool, Dict[str, Any]] """Register this user @@ -772,7 +798,7 @@ def register_user(self, data): attribute_not_exists(node_id) or (connected_at < :connected_at) )""" - result = self.table.update_item( + result = self.table(session).update_item( Key=db_key, UpdateExpression=expr, ConditionExpression=cond, @@ -783,7 +809,7 @@ def register_user(self, data): r = {} for key, value in result["Attributes"].items(): try: - r[key] = self.table._dynamizer.decode(value) + r[key] = self.table(session)._dynamizer.decode(value) except (TypeError, AttributeError): # pragma: nocover # Included for safety as moto has occasionally made # this not work @@ -799,13 +825,13 @@ def register_user(self, data): raise @track_provisioned - def drop_user(self, uaid): + def drop_user(self, uaid, session=get_session()): # type: (str) -> bool """Drops a user record""" # The following hack ensures that only uaids that exist and are # deleted return true. try: - item = self.table.get_item( + item = self.table(session).get_item( Key={ 'uaid': hasher(uaid) }, @@ -815,18 +841,18 @@ def drop_user(self, uaid): return False except ClientError: pass - result = self.table.delete_item(Key={'uaid': hasher(uaid)}) + result = self.table(session).delete_item(Key={'uaid': hasher(uaid)}) return result['ResponseMetadata']['HTTPStatusCode'] == 200 - def delete_uaids(self, uaids): + def delete_uaids(self, uaids, session=get_session()): # type: (List[str]) -> None """Issue a batch delete call for the given uaids""" - with self.table.batch_writer() as batch: + with self.table(session).batch_writer() as batch: for uaid in uaids: batch.delete_item(Key={'uaid': uaid}) - def drop_old_users(self, months_ago=2): - # type: (int) -> Iterable[int] + def drop_old_users(self, months_ago=2, session=get_session()): + # type: (int, Session) -> Iterable[int] """Drops user records that have no recent connection Utilizes the last_connect index to locate users that haven't @@ -847,6 +873,7 @@ def drop_old_users(self, months_ago=2): quickly as possible. :param months_ago: how many months ago since the last connect + :param session: Session for thread :returns: Iterable of how many deletes were run @@ -855,7 +882,7 @@ def drop_old_users(self, months_ago=2): batched = [] for hash_key in generate_last_connect_values(prior_date): - response = self.table.query( + response = self.table(session).query( KeyConditionExpression=Key("last_connect").eq(hash_key), IndexName="AccessIndex", ) @@ -874,16 +901,16 @@ def drop_old_users(self, months_ago=2): yield len(batched) @track_provisioned - def _update_last_connect(self, uaid, last_connect): - self.table.update_item( + def _update_last_connect(self, uaid, last_connect, session=get_session()): + self.table(session).update_item( Key={"uaid": hasher(uaid)}, UpdateExpression="SET last_connect=:last_connect", ExpressionAttributeValues={":last_connect": last_connect} ) @track_provisioned - def update_message_month(self, uaid, month): - # type: (str, str) -> bool + def update_message_month(self, uaid, month, session=get_session()): + # type: (str, str, Session) -> bool """Update the route tables current_message_month Note that we also update the last_connect at this point since webpush @@ -899,7 +926,7 @@ def update_message_month(self, uaid, month): ":last_connect": generate_last_connect(), ":expiry": _expiry(self._max_ttl), } - self.table.update_item( + self.table(session).update_item( Key=db_key, UpdateExpression=expr, ExpressionAttributeValues=expr_values, @@ -907,8 +934,8 @@ def update_message_month(self, uaid, month): return True @track_provisioned - def clear_node(self, item): - # type: (dict) -> bool + def clear_node(self, item, session=get_session()): + # type: (dict, Session) -> bool """Given a router item and remove the node_id The node_id will only be cleared if the ``connected_at`` matches up @@ -926,7 +953,7 @@ def clear_node(self, item): try: cond = "(node_id = :node) and (connected_at = :conn)" - self.table.put_item( + self.table(session).put_item( Item=item, ConditionExpression=cond, ExpressionAttributeValues={ @@ -953,11 +980,10 @@ class DatabaseManager(object): metrics = attrib() # type: IMetrics router = attrib(default=None) # type: Optional[Router] - message_tables = attrib(default=Factory(dict)) # type: Dict[str, Message] + message_tables = attrib(default=Factory(list)) # type: List[str] current_msg_month = attrib(init=False) # type: Optional[str] current_month = attrib(init=False) # type: Optional[int] # for testing: - client = attrib(default=g_client) # type: Optional[Any] def __attrs_post_init__(self): """Initialize sane defaults""" @@ -990,8 +1016,8 @@ def setup(self, preflight_uaid): def setup_tables(self): """Lookup or create the database tables""" self.router = Router( - get_router_table(**asdict(self._router_conf)), - self.metrics + conf=self._router_conf, + metrics=self.metrics ) # Used to determine whether a connection is out of date with current # db objects. There are three noteworty cases: @@ -1007,13 +1033,7 @@ def setup_tables(self): def message(self): # type: () -> Message """Property that access the current message table""" - return self.message_tables[self.current_msg_month] - - @message.setter - def message(self, value): - # type: (Message) -> None - """Setter to set the current message table""" - self.message_tables[self.current_msg_month] = value + return Message(self.current_msg_month, self.metrics) def _tomorrow(self): # type: () -> datetime.date @@ -1026,34 +1046,34 @@ def create_initial_message_tables(self): an entry for tomorrow, if tomorrow is a new month. """ + session = get_session() mconf = self._message_conf today = datetime.date.today() - last_month = get_rotating_message_table( + last_month = get_rotating_message_tablename( prefix=mconf.tablename, delta=-1, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + session=session, ) - this_month = get_rotating_message_table( + this_month = get_rotating_message_tablename( prefix=mconf.tablename, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + session=session, ) self.current_month = today.month - self.current_msg_month = this_month.table_name - self.message_tables = { - last_month.table_name: Message(last_month, self.metrics), - this_month.table_name: Message(this_month, self.metrics) - } + self.current_msg_month = this_month + self.message_tables = [last_month, this_month] if self._tomorrow().month != today.month: - next_month = get_rotating_message_table( + next_month = get_rotating_message_tablename( prefix=mconf.tablename, delta=1, message_read_throughput=mconf.read_throughput, - message_write_throughput=mconf.write_throughput + message_write_throughput=mconf.write_throughput, + session=session, ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) + self.message_tables.append(next_month) @inlineCallbacks def update_rotating_tables(self): @@ -1071,16 +1091,14 @@ def update_rotating_tables(self): if ((tomorrow.month != today.month) and sorted(self.message_tables.keys())[-1] != tomorrow.month): next_month = yield deferToThread( - get_rotating_message_table, + get_rotating_message_tablename, prefix=mconf.tablename, delta=0, date=tomorrow, message_read_throughput=mconf.read_throughput, message_write_throughput=mconf.write_throughput ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) - + self.message_tables.append(next_month) if today.month == self.current_month: # No change in month, we're fine. returnValue(False) @@ -1088,7 +1106,7 @@ def update_rotating_tables(self): # Get tables for the new month, and verify they exist before we try to # switch over message_table = yield deferToThread( - get_rotating_message_table, + get_rotating_message_tablename, prefix=mconf.tablename, message_read_throughput=mconf.read_throughput, message_write_throughput=mconf.write_throughput @@ -1096,7 +1114,5 @@ def update_rotating_tables(self): # Both tables found, safe to switch-over self.current_month = today.month - self.current_msg_month = message_table.table_name - self.message_tables[self.current_msg_month] = Message( - message_table, self.metrics) + self.current_msg_month = message_table returnValue(True) diff --git a/autopush/tests/test_db.py b/autopush/tests/test_db.py index 920565dc..866dbfe8 100644 --- a/autopush/tests/test_db.py +++ b/autopush/tests/test_db.py @@ -11,7 +11,7 @@ import pytest from autopush.db import ( - get_rotating_message_table, + get_rotating_message_tablename, get_router_table, create_router_table, preflight_check, @@ -21,7 +21,8 @@ generate_last_connect, make_rotating_tablename, _drop_table, - _make_table) + _make_table, + get_session) from autopush.exceptions import AutopushException from autopush.metrics import SinkMetrics from autopush.utils import WebPushNotification @@ -44,17 +45,18 @@ def make_webpush_notification(uaid, chid, ttl=100): class DbUtilsTest(unittest.TestCase): def test_make_table(self): + fake_session = Mock() fake_func = Mock() fake_table = "DoesNotExist_{}".format(uuid.uuid4()) - _make_table(fake_func, fake_table, 5, 10) - assert fake_func.call_args[0] == (fake_table, 5, 10) + _make_table(fake_func, fake_table, 5, 10, session=fake_session) + assert fake_func.call_args[0] == (fake_table, 5, 10, fake_session) class DbCheckTestCase(unittest.TestCase): def test_preflight_check_fail(self): router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + message = Message(get_rotating_message_tablename(), SinkMetrics()) def raise_exc(*args, **kwargs): # pragma: no cover raise Exception("Oops") @@ -66,8 +68,8 @@ def raise_exc(*args, **kwargs): # pragma: no cover preflight_check(message, router) def test_preflight_check(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + router = Router({}, SinkMetrics()) + message = Message(get_rotating_message_tablename(), SinkMetrics()) pf_uaid = "deadbeef00000000deadbeef01010101" preflight_check(message, router, pf_uaid) @@ -78,8 +80,9 @@ def test_preflight_check(self): router.get_uaid(pf_uaid) def test_preflight_check_wait(self): - router = Router(get_router_table(), SinkMetrics()) - message = Message(get_rotating_message_table(), SinkMetrics()) + router = Router({}, SinkMetrics()) + message = Message(get_rotating_message_tablename(), + SinkMetrics()) values = ["PENDING", "ACTIVE"] message.table_status = Mock(side_effect=values) @@ -126,7 +129,8 @@ def test_normalize_id(self): class MessageTestCase(unittest.TestCase): def setUp(self): - table = get_rotating_message_table() + table = get_rotating_message_tablename() + self._session = get_session() self.real_table = table self.uaid = str(uuid.uuid4()) @@ -135,12 +139,13 @@ def tearDown(self): def test_register(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() - message = Message(m, SinkMetrics()) - message.register_channel(self.uaid, chid) + m = get_rotating_message_tablename() + message = Message(m, SinkMetrics()) + message.register_channel(self.uaid, chid, session=self._session) + lm = self._session.Table(m) # Verify it's in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -157,12 +162,14 @@ def test_register(self): def test_unregister(self): chid = str(uuid.uuid4()) - m = get_rotating_message_table() + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) # Verify its in the db - response = m.query( + lm = self._session.Table(m) + # Verify it's in the db + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -179,10 +186,10 @@ def test_unregister(self): assert len(results) == 1 assert results[0]["chids"] == {chid} - message.unregister_channel(self.uaid, chid) + message.unregister_channel(self.uaid, chid, session=self._session) # Verify its not in the db - response = m.query( + response = lm.query( KeyConditions={ 'uaid': { 'AttributeValueList': [self.uaid], @@ -200,48 +207,52 @@ def test_unregister(self): assert results[0].get("chids") is None # Test for the very unlikely case that there's no 'chid' - m.update_item = Mock(return_value={ + mtable = Mock() + mtable.update_item = Mock(return_value={ 'Attributes': {'uaid': self.uaid}, 'ResponseMetaData': {} }) - r = message.unregister_channel(self.uaid, dummy_chid) + message.table = Mock(return_value=mtable) + r = message.unregister_channel(self.uaid, dummy_chid, + session=self._session) assert r is False def test_all_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) - _, chans = message.all_channels(self.uaid) + _, chans = message.all_channels(self.uaid, session=self._session) assert chid in chans assert chid2 in chans - message.unregister_channel(self.uaid, chid2) - _, chans = message.all_channels(self.uaid) + message.unregister_channel(self.uaid, chid2, session=self._session) + _, chans = message.all_channels(self.uaid, session=self._session) assert chid2 not in chans assert chid in chans def test_all_channels_fail(self): - m = get_rotating_message_table() + + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) - message.table.get_item = Mock() - message.table.get_item.return_value = { + mtable = Mock() + mtable.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } - + message.table = Mock(return_value=mtable) res = message.all_channels(self.uaid) assert res == (False, set([])) def test_save_channels(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -253,7 +264,7 @@ def test_save_channels(self): assert chans == new_chans def test_all_channels_no_uaid(self): - m = get_rotating_message_table() + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) exists, chans = message.all_channels(dummy_uaid) assert chans == set([]) @@ -261,7 +272,7 @@ def test_all_channels_no_uaid(self): def test_message_storage(self): chid = str(uuid.uuid4()) chid2 = str(uuid.uuid4()) - m = get_rotating_message_table() + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -283,7 +294,7 @@ def test_message_storage_overwrite(self): notif2 = make_webpush_notification(self.uaid, chid) notif3 = make_webpush_notification(self.uaid, chid2) notif2.message_id = notif1.message_id - m = get_rotating_message_table() + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) message.register_channel(self.uaid, chid) message.register_channel(self.uaid, chid2) @@ -298,14 +309,15 @@ def test_message_storage_overwrite(self): def test_message_delete_fail_condition(self): notif = make_webpush_notification(dummy_uaid, dummy_chid) notif.message_id = notif.update_id = dummy_uaid - m = get_rotating_message_table() + m = get_rotating_message_tablename() message = Message(m, SinkMetrics()) def raise_condition(*args, **kwargs): raise ClientError({}, 'delete_item') - message.table = Mock() - message.table.delete_item.side_effect = raise_condition + m_de = Mock() + m_de.delete_item = Mock(side_effect=raise_condition) + message.table = Mock(return_value=m_de) result = message.delete_message(notif) assert result is False @@ -314,8 +326,8 @@ def test_message_rotate_table_with_date(self): future = (datetime.today() + timedelta(days=32)).date() tbl_name = make_rotating_tablename(prefix, date=future) - m = get_rotating_message_table(prefix=prefix, date=future) - assert m.table_name == tbl_name + m = get_rotating_message_tablename(prefix=prefix, date=future) + assert m == tbl_name # Clean up the temp table. _drop_table(tbl_name) @@ -323,13 +335,7 @@ def test_message_rotate_table_with_date(self): class RouterTestCase(unittest.TestCase): @classmethod def setup_class(self): - table = get_router_table() - self.real_table = table - self.real_connection = table.meta.client - - @classmethod - def teardown_class(self): - self.real_table.meta.client = self.real_connection + self._session = get_session() def _create_minimal_record(self): data = { @@ -342,23 +348,24 @@ def _create_minimal_record(self): def test_drop_old_users(self): # First create a bunch of users - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) # Purge any existing users from previous runs. - router.drop_old_users(0) + router.drop_old_users(0, session=self._session) for _ in range(0, 53): - router.register_user(self._create_minimal_record()) + router.register_user(self._create_minimal_record(), + session=self._session) - results = router.drop_old_users(months_ago=0) + results = router.drop_old_users(months_ago=0, + session=self._session) assert list(results) == [25, 25, 3] def test_custom_tablename(self): db_name = "router_%s" % uuid.uuid4() - assert not table_exists(db_name) - create_router_table(db_name) - assert table_exists(db_name) + assert not table_exists(db_name, session=self._session) + create_router_table(db_name, session=self._session) + assert table_exists(db_name, session=self._session) # Clean up the temp table. - _drop_table(db_name) + _drop_table(db_name, session=self._session) def test_provisioning(self): db_name = "router_%s" % uuid.uuid4() @@ -369,114 +376,114 @@ def test_provisioning(self): def test_no_uaid_found(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) with pytest.raises(ItemNotFound): - router.get_uaid(uaid) + router.get_uaid(uaid, session=self._session) def test_uaid_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) router.table = Mock() def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.get_item.side_effect = raise_condition + mm = Mock() + mm.get_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: - router.get_uaid(uaid="asdf") + router.get_uaid(uaid="asdf", session=self._session) assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_register_user_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router({}, SinkMetrics()) + mm = Mock() + mm.client = Mock() + router.table = Mock(return_value=mm) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + mm.update_item = Mock(side_effect=raise_condition) with pytest.raises(ClientError) as ex: router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, - router_type="webpush")) + router_type="webpush"), + session=self._session) assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_register_user_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.meta.client = Mock() + router = Router({}, SinkMetrics()) + router.table(self._session).meta.client = Mock() def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - - router.table.update_item = Mock(side_effect=raise_error) + mm = Mock() + mm.update_item = Mock(side_effect=raise_error) + router.table = Mock(return_value=mm) res = router.register_user(dict(uaid=dummy_uaid, node_id="me", connected_at=1234, - router_type="webpush")) + router_type="webpush"), + session=self._session) assert res == (False, {}) def test_clear_node_provision_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ProvisionedThroughputExceededException'}}, 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + mm = Mock() + mm.put_item = Mock(side_effect=raise_condition) + router.table = Mock(return_value=mm) with pytest.raises(ClientError) as ex: router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", node_id="asdf", - router_type="webpush")) + router_type="webpush"), + session=self._session) assert (ex.value.response['Error']['Code'] == "ProvisionedThroughputExceededException") def test_clear_node_condition_failed(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) def raise_error(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_put_item' ) - router.table.put_item = Mock(side_effect=raise_error) + router.table(self._session).put_item = Mock(side_effect=raise_error) res = router.clear_node(dict(uaid=dummy_uaid, connected_at="1234", node_id="asdf", - router_type="webpush")) + router_type="webpush"), + session=self._session) + assert res is False def test_incomplete_uaid(self): # Older records may be incomplete. We can't inject them using normal # methods. uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router({}, SinkMetrics()) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 200 }, @@ -484,83 +491,89 @@ def test_incomplete_uaid(self): "uaid": uuid.uuid4().hex } } + mm.delete_item.return_value = { + "ResponseMetadata": { + "HTTPStatusCode": 200 + }, + } + router.table = Mock(return_value=mm) + router.drop_user = Mock() try: - router.register_user(dict(uaid=uaid)) + router.register_user(dict(uaid=uaid), + session=self._session) except AutopushException: pass with pytest.raises(ItemNotFound): - router.get_uaid(uaid) + router.get_uaid(uaid, session=self._session) assert router.drop_user.called def test_failed_uaid(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) - router.table.get_item = Mock() - router.drop_user = Mock() - router.table.get_item.return_value = { + router = Router({}, SinkMetrics()) + mm = Mock() + mm.get_item = Mock() + mm.get_item.return_value = { "ResponseMetadata": { "HTTPStatusCode": 400 }, } + router.table = Mock(return_value=mm) + router.drop_user = Mock() with pytest.raises(ItemNotFound): - router.get_uaid(uaid) + router.get_uaid(uaid, session=self._session) def test_save_new(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) # Sadly, moto currently does not return an empty value like boto # when not updating data. - router.table.update_item = Mock(return_value={}) + router.table(self._session).update_item = Mock(return_value={}) result = router.register_user(dict(uaid=dummy_uaid, node_id="me", router_type="webpush", - connected_at=1234)) + connected_at=1234), + session=self._session) assert result[0] is True def test_save_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) def raise_condition(*args, **kwargs): - import autopush.db - raise autopush.db.g_client.exceptions.ClientError( + raise ClientError( {'Error': {'Code': 'ConditionalCheckFailedException'}}, 'mock_update_item' ) - router.table.update_item = Mock(side_effect=raise_condition) + router.table(self._session).update_item = Mock(side_effect=raise_condition) router_data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, router_type="webpush") - result = router.register_user(router_data) + result = router.register_user(router_data, session=self._session) assert result == (False, {}) def test_node_clear(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) # Register a node user router.register_user(dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234, - router_type="webpush")) + router_type="webpush"), + session=self._session) # Verify - user = router.get_uaid(dummy_uaid) + user = router.get_uaid(dummy_uaid, session=self._session) assert user["node_id"] == "asdf" assert user["connected_at"] == 1234 assert user["router_type"] == "webpush" # Clear - router.clear_node(user) + router.clear_node(user, session=self._session) # Verify - user = router.get_uaid(dummy_uaid) + user = router.get_uaid(dummy_uaid, session=self._session) assert user.get("node_id") is None assert user["connected_at"] == 1234 assert user["router_type"] == "webpush" def test_node_clear_fail(self): - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) def raise_condition(*args, **kwargs): raise ClientError( @@ -568,21 +581,21 @@ def raise_condition(*args, **kwargs): 'mock_update_item' ) - router.table.put_item = Mock(side_effect=raise_condition) + router.table(session=self._session).put_item = Mock( + side_effect=raise_condition) data = dict(uaid=dummy_uaid, node_id="asdf", connected_at=1234) - result = router.clear_node(data) + result = router.clear_node(data, session=self._session) assert result is False def test_drop_user(self): uaid = str(uuid.uuid4()) - r = get_router_table() - router = Router(r, SinkMetrics()) + router = Router({}, SinkMetrics()) # Register a node user router.register_user(dict(uaid=uaid, node_id="asdf", router_type="webpush", connected_at=1234)) - result = router.drop_user(uaid) + result = router.drop_user(uaid, session=self._session) assert result is True # Deleting already deleted record should return false. - result = router.drop_user(uaid) + result = router.drop_user(uaid, session=self._session) assert result is False diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index af7426aa..04226d3a 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -1,6 +1,7 @@ import json import uuid +import mock import twisted.internet.base from cryptography.fernet import Fernet, InvalidToken from mock import Mock, patch @@ -663,18 +664,20 @@ def test_delete_bad_router(self): assert resp.get_status() == 400 @inlineCallbacks - def test_get(self): + @mock.patch("autopush.db.Message") + def test_get(self, mm): chids = [str(dummy_chid), str(dummy_uaid)] - self.db.message.all_channels = Mock() - self.db.message.all_channels.return_value = (True, chids) + ma = Mock(return_value=(True, chids)) + mm.all_channels = ma resp = yield self.client.get( self.url(router_type="test", router_token="test", uaid=dummy_uaid.hex), headers={"Authorization": self.auth} ) - self.db.message.all_channels.assert_called_with(str(dummy_uaid)) + ma.assert_called_with(str(dummy_uaid)) payload = json.loads(resp.content) assert chids == payload['channelIDs'] assert dummy_uaid.hex == payload['uaid'] + diff --git a/autopush/web/health.py b/autopush/web/health.py index d28265b2..7f108438 100644 --- a/autopush/web/health.py +++ b/autopush/web/health.py @@ -40,7 +40,7 @@ def get(self): def _check_table(self, table, name_over=None): """Checks the tables known about in DynamoDB""" - d = deferToThread(table_exists, table.table_name, self.db.client) + d = deferToThread(table_exists, table.table_name) d.addCallback(self._check_success, name_over or table.table_name) d.addErrback(self._check_error, name_over or table.table_name) return d