diff --git a/autopush/base.py b/autopush/base.py index 39c85b5b..21b6ce67 100644 --- a/autopush/base.py +++ b/autopush/base.py @@ -1,21 +1,41 @@ import sys import uuid +from typing import TYPE_CHECKING import cyclone.web from twisted.logger import Logger from twisted.python import failure +if TYPE_CHECKING: # pragma: nocover + from autopush.db import DatabaseManager # noqa + from autopush.metrics import IMetrics # noqa + from autopush.settings import AutopushSettings # noqa + class BaseHandler(cyclone.web.RequestHandler): """Base cyclone RequestHandler for autopush""" log = Logger() - def initialize(self, ap_settings): - """Setup basic attributes from AutopushSettings""" - self.ap_settings = ap_settings + def initialize(self): + """Initialize info from the client""" self._client_info = self._init_info() + @property + def ap_settings(self): + # type: () -> AutopushSettings + return self.application.ap_settings + + @property + def db(self): + # type: () -> DatabaseManager + return self.application.db + + @property + def metrics(self): + # type: () -> IMetrics + return self.db.metrics + def _init_info(self): return dict( ami_id=self.ap_settings.ami_id, diff --git a/autopush/db.py b/autopush/db.py index 9631ef6c..7060598e 100644 --- a/autopush/db.py +++ b/autopush/db.py @@ -37,6 +37,11 @@ import uuid from functools import wraps +from attr import ( + attrs, + attrib, + Factory +) from boto.exception import JSONResponseError, BotoServerError from boto.dynamodb2.exceptions import ( ConditionalCheckFailedException, @@ -51,6 +56,7 @@ Any, Callable, Dict, + Generator, Iterable, List, Optional, @@ -58,7 +64,11 @@ TypeVar, Tuple, ) +from twisted.internet.defer import Deferred # noqa +from twisted.internet.defer import inlineCallbacks, returnValue +from twisted.internet.threads import deferToThread +import autopush.metrics from autopush.exceptions import AutopushException from autopush.metrics import IMetrics # noqa from autopush.types import ItemLike # noqa @@ -853,3 +863,136 @@ def clear_node(self, item): return True except ConditionalCheckFailedException: return False + + +@attrs +class DatabaseManager(object): + """Provides database access""" + + storage = attrib() # type: Storage + router = attrib() # type: Router + + metrics = attrib() # type: IMetrics + + message_tables = attrib(default=Factory(dict)) # type: Dict[str, Message] + current_msg_month = attrib(default=None) # type: Optional[str] + current_month = attrib(default=None) # type: Optional[int] + + _message_prefix = attrib(default="message") # type: str + + @classmethod + def from_settings(cls, settings): + router_table = get_router_table( + settings.router_tablename, + settings.router_read_throughput, + settings.router_write_throughput + ) + storage_table = get_storage_table( + settings.storage_tablename, + settings.storage_read_throughput, + settings.storage_write_throughput + ) + get_rotating_message_table( + settings.message_tablename, + message_read_throughput=settings.message_read_throughput, + message_write_throughput=settings.message_write_throughput + ) + metrics = autopush.metrics.from_settings(settings) + return cls( + storage=Storage(storage_table, metrics), + router=Router(router_table, metrics), + message_prefix=settings.message_tablename, + metrics=metrics + ) + + def setup(self, preflight_uaid): + # type: (str) -> None + """Setup metrics, message tables and perform preflight_check""" + self.metrics.start() + + # Used to determine whether a connection is out of date with current + # db objects. 
There are three noteworty cases: + # 1 "Last Month" the table requires a rollover. + # 2 "This Month" the most common case. + # 3 "Next Month" where the system will soon be rolling over, but with + # timing, some nodes may roll over sooner. Ensuring the next month's + # table is present before the switchover is the main reason for this, + # just in case some nodes do switch sooner. + self.create_initial_message_tables() + + preflight_check(self.storage, self.router, preflight_uaid) + + @property + def message(self): + # type: () -> Message + """Property that access the current message table""" + return self.message_tables[self.current_msg_month] + + @message.setter + def message(self, value): + # type: (Message) -> None + """Setter to set the current message table""" + self.message_tables[self.current_msg_month] = value + + def _tomorrow(self): + # type: () -> datetime.date + return datetime.date.today() + datetime.timedelta(days=1) + + def create_initial_message_tables(self): + """Initializes a dict of the initial rotating messages tables. + + An entry for last months table, an entry for this months table, + an entry for tomorrow, if tomorrow is a new month. + + """ + today = datetime.date.today() + last_month = get_rotating_message_table(self._message_prefix, -1) + this_month = get_rotating_message_table(self._message_prefix) + self.current_month = today.month + self.current_msg_month = this_month.table_name + self.message_tables = { + last_month.table_name: Message(last_month, self.metrics), + this_month.table_name: Message(this_month, self.metrics) + } + if self._tomorrow().month != today.month: + next_month = get_rotating_message_table(self._message_prefix, + delta=1) + self.message_tables[next_month.table_name] = Message( + next_month, self.metrics) + + @inlineCallbacks + def update_rotating_tables(self): + # type: () -> Generator + """This method is intended to be tasked to run periodically off the + twisted event hub to rotate tables. + + When today is a new month from yesterday, then we swap out all the + table objects on the settings object. + + """ + today = datetime.date.today() + tomorrow = self._tomorrow() + if ((tomorrow.month != today.month) and + sorted(self.message_tables.keys())[-1] != tomorrow.month): + next_month = yield deferToThread( + get_rotating_message_table, + self._message_prefix, 0, tomorrow + ) + self.message_tables[next_month.table_name] = Message( + next_month, self.metrics) + + if today.month == self.current_month: + # No change in month, we're fine. 
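+            # Returning False signals that no table rotation was needed
+            # for this cycle.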
+ returnValue(False) + + # Get tables for the new month, and verify they exist before we try to + # switch over + message_table = yield deferToThread(get_rotating_message_table, + self._message_prefix) + + # Both tables found, safe to switch-over + self.current_month = today.month + self.current_msg_month = message_table.table_name + self.message_tables[self.current_msg_month] = Message( + message_table, self.metrics) + returnValue(True) diff --git a/autopush/diagnostic_cli.py b/autopush/diagnostic_cli.py index 8810f568..a13be2ec 100644 --- a/autopush/diagnostic_cli.py +++ b/autopush/diagnostic_cli.py @@ -6,6 +6,7 @@ import configargparse from twisted.logger import Logger +from autopush.db import DatabaseManager from autopush.main import AutopushMultiService from autopush.main_argparse import add_shared_args from autopush.settings import AutopushSettings @@ -19,12 +20,15 @@ class EndpointDiagnosticCLI(object): def __init__(self, sysargs, use_files=True): args = self._load_args(sysargs, use_files) - self._settings = AutopushSettings( + self._settings = settings = AutopushSettings( crypto_key=args.crypto_key, router_tablename=args.router_tablename, storage_tablename=args.storage_tablename, message_tablename=args.message_tablename, + statsd_host=None, ) + self.db = DatabaseManager.from_settings(settings) + self.db.setup(settings.preflight_uaid) self._endpoint = args.endpoint self._pp = pprint.PrettyPrinter(indent=4) @@ -56,6 +60,7 @@ def run(self): api_ver, token = md.get("api_ver", "v1"), md["token"] parsed = self._settings.parse_endpoint( + self.db.metrics, token=token, version=api_ver, ) @@ -63,13 +68,13 @@ def run(self): print("UAID: {}\nCHID: {}\n".format(uaid, chid)) - rec = self._settings.router.get_uaid(uaid) + rec = self.db.router.get_uaid(uaid) print("Router record:") self._pp.pprint(rec._data) print("\n") mess_table = rec["current_month"] - chans = self._settings.message_tables[mess_table].all_channels(uaid) + chans = self.db.message_tables[mess_table].all_channels(uaid) print("Channels in message table:") self._pp.pprint(chans) diff --git a/autopush/http.py b/autopush/http.py index dd181771..9f391fa2 100644 --- a/autopush/http.py +++ b/autopush/http.py @@ -2,6 +2,7 @@ from typing import ( # noqa Any, Callable, + Dict, Optional, Sequence, Tuple, @@ -11,6 +12,9 @@ import cyclone.web from autopush.base import BaseHandler +from autopush.db import DatabaseManager +from autopush.router import routers_from_settings +from autopush.router.interface import IRouter # noqa from autopush.settings import AutopushSettings # noqa from autopush.ssl import AutopushSSLContextFactory from autopush.web.health import ( @@ -53,45 +57,47 @@ class BaseHTTPFactory(cyclone.web.Application): ) def __init__(self, - ap_settings, - handlers=None, - log_function=skip_request_logging, + ap_settings, # type: AutopushSettings + db, # type: DatabaseManager + routers, # type: Dict[str, IRouter] + handlers=None, # type: APHandlers + log_function=skip_request_logging, # type: CycloneLogger **kwargs): - # type: (AutopushSettings, APHandlers, CycloneLogger, **Any) -> None + # type: (...) 
-> None self.ap_settings = ap_settings + self.db = db + self.routers = routers self.noisy = ap_settings.debug cyclone.web.Application.__init__( self, + handlers=self.ap_handlers if handlers is None else handlers, default_host=self._hostname, debug=ap_settings.debug, log_function=log_function, **kwargs ) - self.add_ap_handlers( - self.ap_handlers if handlers is None else handlers) - - def add_ap_handlers(self, handlers): - # type: (APHandlers) -> None - """Add BaseHandlers w/ their appropriate handler kwargs""" - h_kwargs = dict(ap_settings=self.ap_settings) - self.add_handlers( - ".*$", - [(pattern, handler, h_kwargs) for pattern, handler in handlers] - ) def add_health_handlers(self): """Add the health check HTTP handlers""" - self.add_ap_handlers(self.health_ap_handlers) + self.add_handlers(".*$", self.health_ap_handlers) @property def _hostname(self): return self.ap_settings.hostname @classmethod - def for_handler(cls, handler_cls, *args, **kwargs): - # type: (Type[BaseHandler], *Any, **Any) -> BaseHTTPFactory - """Create a cyclone app around a specific handler_cls. + def for_handler(cls, + handler_cls, # Type[BaseHTTPFactory] + ap_settings, # type: AutopushSettings + db=None, # type: Optional[DatabaseManager] + routers=None, # type: Optional[Dict[str, IRouter]] + **kwargs): + # type: (...) -> BaseHTTPFactory + """Create a cyclone app around a specific handler_cls for tests. + + Creates an uninitialized (no setup() called) DatabaseManager + from settings if one isn't specified. handler_cls must be included in ap_handlers or a ValueError is thrown. @@ -101,7 +107,17 @@ def for_handler(cls, handler_cls, *args, **kwargs): raise ValueError("handler_cls incompatibile with handlers kwarg") for pattern, handler in cls.ap_handlers + cls.health_ap_handlers: if handler is handler_cls: - return cls(handlers=[(pattern, handler)], *args, **kwargs) + if db is None: + db = DatabaseManager.from_settings(ap_settings) + if routers is None: + routers = routers_from_settings(ap_settings, db) + return cls( + ap_settings, + db=db, + routers=routers, + handlers=[(pattern, handler)], + **kwargs + ) raise ValueError("{!r} not in ap_handlers".format( handler_cls)) # pragma: nocover diff --git a/autopush/main.py b/autopush/main.py index 235d14df..a76cd589 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -1,5 +1,4 @@ """autopush/autoendpoint daemon scripts""" -import json import os from argparse import Namespace # noqa @@ -17,20 +16,19 @@ Any, Optional, Sequence, - Union ) -import autopush.db as db from autopush.http import ( InternalRouterHTTPFactory, EndpointHTTPFactory, MemUsageHTTPFactory ) -import autopush.utils as utils from autopush.exceptions import InvalidSettings +from autopush.db import DatabaseManager from autopush.haproxy import HAProxyServerEndpoint from autopush.logging import PushLogger from autopush.main_argparse import parse_connection, parse_endpoint +from autopush.router import routers_from_settings from autopush.settings import AutopushSettings from autopush.websocket import ( ConnectionWSSite, @@ -41,115 +39,6 @@ log = Logger() -def make_settings(args, **kwargs): - """Helper function to make a :class:`AutopushSettings` object""" - router_conf = {} - if args.key_hash: - db.key_hash = args.key_hash - # Some routers require a websocket to timeout on idle (e.g. 
UDP) - if args.wake_pem is not None and args.wake_timeout != 0: - router_conf["simplepush"] = {"idle": args.wake_timeout, - "server": args.wake_server, - "cert": args.wake_pem} - if args.apns_creds: - # if you have the critical elements for each external router, create it - try: - router_conf["apns"] = json.loads(args.apns_creds) - except (ValueError, TypeError): - raise InvalidSettings( - "Invalid JSON specified for APNS config options") - if args.gcm_enabled: - # Create a common gcmclient - try: - sender_ids = json.loads(args.senderid_list) - except (ValueError, TypeError): - raise InvalidSettings("Invalid JSON specified for senderid_list") - try: - # This is an init check to verify that things are configured - # correctly. Otherwise errors may creep in later that go - # unaccounted. - sender_ids[sender_ids.keys()[0]] - except (IndexError, TypeError): - raise InvalidSettings("No GCM SenderIDs specified or found.") - router_conf["gcm"] = {"ttl": args.gcm_ttl, - "dryrun": args.gcm_dryrun, - "max_data": args.max_data, - "collapsekey": args.gcm_collapsekey, - "senderIDs": sender_ids} - - client_certs = None - # endpoint only - if getattr(args, 'client_certs', None): - try: - client_certs_arg = json.loads(args.client_certs) - except (ValueError, TypeError): - raise InvalidSettings("Invalid JSON specified for client_certs") - if client_certs_arg: - if not args.ssl_key: - raise InvalidSettings("client_certs specified without SSL " - "enabled (no ssl_key specified)") - client_certs = {} - for name, sigs in client_certs_arg.iteritems(): - if not isinstance(sigs, list): - raise InvalidSettings( - "Invalid JSON specified for client_certs") - for sig in sigs: - sig = sig.upper() - if (not name or not utils.CLIENT_SHA256_RE.match(sig) or - sig in client_certs): - raise InvalidSettings("Invalid client_certs argument") - client_certs[sig] = name - - if args.fcm_enabled: - # Create a common gcmclient - if not args.fcm_auth: - raise InvalidSettings("No Authorization Key found for FCM") - if not args.fcm_senderid: - raise InvalidSettings("No SenderID found for FCM") - router_conf["fcm"] = {"ttl": args.fcm_ttl, - "dryrun": args.fcm_dryrun, - "max_data": args.max_data, - "collapsekey": args.fcm_collapsekey, - "auth": args.fcm_auth, - "senderid": args.fcm_senderid} - - ami_id = None - # Not a fan of double negatives, but this makes more understandable args - if not args.no_aws: - ami_id = utils.get_amid() - - return AutopushSettings( - crypto_key=args.crypto_key, - datadog_api_key=args.datadog_api_key, - datadog_app_key=args.datadog_app_key, - datadog_flush_interval=args.datadog_flush_interval, - hostname=args.hostname, - statsd_host=args.statsd_host, - statsd_port=args.statsd_port, - router_conf=router_conf, - router_tablename=args.router_tablename, - storage_tablename=args.storage_tablename, - storage_read_throughput=args.storage_read_throughput, - storage_write_throughput=args.storage_write_throughput, - message_tablename=args.message_tablename, - message_read_throughput=args.message_read_throughput, - message_write_throughput=args.message_write_throughput, - router_read_throughput=args.router_read_throughput, - router_write_throughput=args.router_write_throughput, - resolve_hostname=args.resolve_hostname, - wake_timeout=args.wake_timeout, - ami_id=ami_id, - client_certs=client_certs, - msg_limit=args.msg_limit, - connect_timeout=args.connection_timeout, - memusage_port=args.memusage_port, - ssl_key=args.ssl_key, - ssl_cert=args.ssl_cert, - ssl_dh_param=args.ssl_dh_param, - **kwargs - ) - - class 
AutopushMultiService(MultiService): shared_config_files = ( @@ -168,6 +57,7 @@ def __init__(self, settings): # type: (AutopushSettings) -> None super(AutopushMultiService, self).__init__() self.settings = settings + self.db = DatabaseManager.from_settings(settings) @staticmethod def parse_args(config_files, args): @@ -195,7 +85,7 @@ def add_timer(self, *args, **kwargs): def add_memusage(self): """Add the memusage Service""" - factory = MemUsageHTTPFactory(self.settings) + factory = MemUsageHTTPFactory(self.settings, None, None) self.addService( TCPServer(self.settings.memusage_port, factory, reactor=reactor)) @@ -211,7 +101,7 @@ def _from_argparse(cls, ns, **kwargs): """Create an instance from argparse/additional kwargs""" # Add some entropy to prevent potential conflicts. postfix = os.urandom(4).encode('hex').ljust(8, '0') - settings = make_settings( + settings = AutopushSettings.from_argparse( ns, debug=ns.debug, preflight_uaid="deadbeef000000000deadbeef" + postfix, @@ -261,8 +151,12 @@ class EndpointApplication(AutopushMultiService): endpoint_factory = EndpointHTTPFactory + def __init__(self, *args, **kwargs): + super(EndpointApplication, self).__init__(*args, **kwargs) + self.routers = routers_from_settings(self.settings, self.db) + def setup(self, rotate_tables=True): - self.settings.metrics.start() + self.db.setup(self.settings.preflight_uaid) self.add_endpoint() if self.settings.memusage_port: @@ -270,13 +164,13 @@ def setup(self, rotate_tables=True): # Start the table rotation checker/updater if rotate_tables: - self.add_timer(60, self.settings.update_rotating_tables) + self.add_timer(60, self.db.update_rotating_tables) def add_endpoint(self): """Start the Endpoint HTTP router""" settings = self.settings - factory = self.endpoint_factory(settings) + factory = self.endpoint_factory(settings, self.db, self.routers) factory.protocol.maxData = settings.max_data factory.add_health_handlers() ssl_cf = factory.ssl_cf() @@ -323,7 +217,7 @@ class ConnectionApplication(AutopushMultiService): websocket_site_factory = ConnectionWSSite def setup(self, rotate_tables=True): - self.settings.metrics.start() + self.db.setup(self.settings.preflight_uaid) self.add_internal_router() if self.settings.memusage_port: @@ -333,11 +227,11 @@ def setup(self, rotate_tables=True): # Start the table rotation checker/updater if rotate_tables: - self.add_timer(60, self.settings.update_rotating_tables) + self.add_timer(60, self.db.update_rotating_tables) def add_internal_router(self): """Start the internal HTTP notification router""" - factory = self.internal_router_factory(self.settings) + factory = self.internal_router_factory(self.settings, self.db, None) factory.add_health_handlers() self.add_maybe_ssl(self.settings.router_port, factory, factory.ssl_cf()) @@ -345,10 +239,11 @@ def add_internal_router(self): def add_websocket(self): """Start the public WebSocket server""" settings = self.settings - ws_factory = self.websocket_factory(settings) + ws_factory = self.websocket_factory(settings, self.db) site_factory = self.websocket_site_factory(settings, ws_factory) self.add_maybe_ssl(settings.port, site_factory, site_factory.ssl_cf()) - self.add_timer(1.0, periodic_reporter, settings, ws_factory) + self.add_timer(1.0, periodic_reporter, settings, self.db.metrics, + ws_factory) @classmethod def from_argparse(cls, ns): diff --git a/autopush/metrics.py b/autopush/metrics.py index b4375cbe..d05ee307 100644 --- a/autopush/metrics.py +++ b/autopush/metrics.py @@ -1,4 +1,6 @@ """Metrics interface and implementations""" 
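+# TYPE_CHECKING is used below to import AutopushSettings for type comments
+# only, avoiding a runtime import of autopush.settings.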
+from typing import TYPE_CHECKING + from twisted.internet import reactor from txstatsd.client import StatsDClientProtocol, TwistedStatsDClient from txstatsd.metrics.metrics import Metrics @@ -6,6 +8,9 @@ import datadog from datadog import ThreadStats +if TYPE_CHECKING: # pragma: nocover + from autopush.settings import AutopushSettings # noqa + class IMetrics(object): """Metrics interface @@ -97,3 +102,19 @@ def gauge(self, name, count, **kwargs): def timing(self, name, duration, **kwargs): self._client.timing(self._prefix_name(name), value=duration, host=self._host, **kwargs) + + +def from_settings(settings): + # type: (AutopushSettings) -> IMetrics + """Create an IMetrics from the given settings""" + if settings.datadog_api_key: + return DatadogMetrics( + hostname=settings.hostname, + api_key=settings.datadog_api_key, + app_key=settings.datadog_app_key, + flush_interval=settings.datadog_flush_interval, + ) + elif settings.statsd_host: + return TwistedMetrics(settings.statsd_host, settings.statsd_port) + else: + return SinkMetrics() diff --git a/autopush/router/__init__.py b/autopush/router/__init__.py index e83db84c..a7387c8f 100644 --- a/autopush/router/__init__.py +++ b/autopush/router/__init__.py @@ -4,11 +4,32 @@ through the appropriate system for a given client. """ +from typing import Dict # noqa + +from autopush.db import DatabaseManager # noqa from autopush.router.apnsrouter import APNSRouter from autopush.router.gcm import GCMRouter +from autopush.router.interface import IRouter # noqa from autopush.router.simple import SimpleRouter from autopush.router.webpush import WebPushRouter from autopush.router.fcm import FCMRouter +from autopush.settings import AutopushSettings # noqa __all__ = ["APNSRouter", "FCMRouter", "GCMRouter", "SimpleRouter", "WebPushRouter"] + + +def routers_from_settings(settings, db): + # type: (AutopushSettings, DatabaseManager) -> Dict[str, IRouter] + """Create a dict of IRouters for the given settings""" + router_conf = settings.router_conf + routers = dict( + simplepush=SimpleRouter( + settings, router_conf.get("simplepush"), db), + webpush=WebPushRouter(settings, None, db) + ) + if 'apns' in router_conf: + routers["apns"] = APNSRouter(settings, router_conf["apns"], db.metrics) + if 'gcm' in router_conf: + routers["gcm"] = GCMRouter(settings, router_conf["gcm"], db.metrics) + return routers diff --git a/autopush/router/apnsrouter.py b/autopush/router/apnsrouter.py index ba1507e4..ee188cd1 100644 --- a/autopush/router/apnsrouter.py +++ b/autopush/router/apnsrouter.py @@ -44,10 +44,11 @@ def _connect(self, rel_channel, load_connections=True): APNS_MAX_CONNECTIONS), topic=cert_info.get("topic", default_topic), logger=self.log, - metrics=self.ap_settings.metrics, + metrics=self.metrics, load_connections=load_connections) - def __init__(self, ap_settings, router_conf, load_connections=True): + def __init__(self, ap_settings, router_conf, metrics, + load_connections=True): """Create a new APNS router and connect to APNS :param ap_settings: Configuration settings @@ -59,9 +60,10 @@ def __init__(self, ap_settings, router_conf, load_connections=True): """ self.ap_settings = ap_settings + self._config = router_conf + self.metrics = metrics self._base_tags = [] self.apns = dict() - self._config = router_conf for rel_channel in self._config: self.apns[rel_channel] = self._connect(rel_channel, load_connections) @@ -142,10 +144,8 @@ def _route(self, notification, router_data): apns_client.send(router_token=router_token, payload=payload, apns_id=apns_id) except 
(ConnectionError, AttributeError) as ex: - self.ap_settings.metrics.increment( - "updates.client.bridge.apns.connection_err", - self._base_tags - ) + self.metrics.increment("updates.client.bridge.apns.connection_err", + self._base_tags) self.log.error("Connection Error sending to APNS", log_failure=Failure(ex)) raise RouterException( @@ -165,10 +165,11 @@ def _route(self, notification, router_data): location = "%s/m/%s" % (self.ap_settings.endpoint_url, notification.version) - self.ap_settings.metrics.increment( + self.metrics.increment( "updates.client.bridge.apns.%s.sent" % router_data["rel_channel"], - self._base_tags) + self._base_tags + ) return RouterResponse(status_code=201, response_body="", headers={"TTL": notification.ttl, "Location": location}, diff --git a/autopush/router/fcm.py b/autopush/router/fcm.py index 6385eb42..8a53a89e 100644 --- a/autopush/router/fcm.py +++ b/autopush/router/fcm.py @@ -100,15 +100,16 @@ class FCMRouter(object): } } - def __init__(self, ap_settings, router_conf): + def __init__(self, ap_settings, router_conf, metrics): """Create a new FCM router and connect to FCM""" + self.ap_settings = ap_settings self.config = router_conf + self.metrics = metrics self.min_ttl = router_conf.get("ttl", 60) self.dryRun = router_conf.get("dryrun", False) self.collapseKey = router_conf.get("collapseKey", "webpush") self.senderID = router_conf.get("senderID") self.auth = router_conf.get("auth") - self.metrics = ap_settings.metrics self._base_tags = [] try: self.fcm = pyfcm.FCMNotification(api_key=self.auth) @@ -117,7 +118,6 @@ def __init__(self, ap_settings, router_conf): ex=e) raise IOError("FCM Bridge not initiated in main") self.log.debug("Starting FCM router...") - self.ap_settings = ap_settings def amend_endpoint_response(self, response, router_data): # type: (JSONDict, JSONDict) -> None diff --git a/autopush/router/gcm.py b/autopush/router/gcm.py index 1fe6f8be..cc692b0a 100644 --- a/autopush/router/gcm.py +++ b/autopush/router/gcm.py @@ -18,9 +18,11 @@ class GCMRouter(object): collapseKey = "simplepush" MAX_TTL = 2419200 - def __init__(self, ap_settings, router_conf): + def __init__(self, ap_settings, router_conf, metrics): """Create a new GCM router and connect to GCM""" + self.ap_settings = ap_settings self.config = router_conf + self.metrics = metrics self.min_ttl = router_conf.get("ttl", 60) self.dryRun = router_conf.get("dryrun", False) self.collapseKey = router_conf.get("collapseKey", "simplepush") @@ -36,11 +38,8 @@ def __init__(self, ap_settings, router_conf): self.gcm[sid] = gcmclient.GCM(auth) except: raise IOError("GCM Bridge not initiated in main") - self.metrics = ap_settings.metrics self._base_tags = [] - self.router_table = ap_settings.router self.log.debug("Starting GCM router...") - self.ap_settings = ap_settings def amend_endpoint_response(self, response, router_data): # type: (JSONDict, JSONDict) -> None diff --git a/autopush/router/interface.py b/autopush/router/interface.py index 3ab09c7f..5b39c75e 100644 --- a/autopush/router/interface.py +++ b/autopush/router/interface.py @@ -24,7 +24,7 @@ def __init__(self, status_code=200, response_body="", router_data=None, class IRouter(object): - def __init__(self, settings, router_conf): + def __init__(self, settings, router_conf, **kwargs): """Initialize the Router to handle notifications and registrations with the given settings and router conf.""" raise NotImplementedError("__init__ must be implemented") diff --git a/autopush/router/simple.py b/autopush/router/simple.py index 96f34313..b1fe981f 
100644 --- a/autopush/router/simple.py +++ b/autopush/router/simple.py @@ -40,13 +40,17 @@ class SimpleRouter(object): """ log = Logger() - def __init__(self, ap_settings, router_conf): + def __init__(self, ap_settings, router_conf, db): """Create a new SimpleRouter""" self.ap_settings = ap_settings - self.metrics = ap_settings.metrics self.conf = router_conf + self.db = db self.waker = None + @property + def metrics(self): + return self.db.metrics + def register(self, uaid, router_data, app_id, *args, **kwargs): # type: (str, JSONDict, str, *Any, **Any) -> None """No additional routing data""" @@ -70,7 +74,7 @@ def route_notification(self, notification, uaid_data): node_id = uaid_data.get("node_id") uaid = uaid_data["uaid"] self.udp = uaid_data.get("udp") - router = self.ap_settings.router + router = self.db.router # Node_id is present, attempt delivery. # - Send Notification to node @@ -181,7 +185,7 @@ def _save_notification(self, uaid_data, notification): """ uaid = uaid_data["uaid"] - return deferToThread(self.ap_settings.storage.save_notification, + return deferToThread(self.db.storage.save_notification, uaid=uaid, chid=notification.channel_id, version=notification.version) diff --git a/autopush/router/webpush.py b/autopush/router/webpush.py index 045d25de..20580dc3 100644 --- a/autopush/router/webpush.py +++ b/autopush/router/webpush.py @@ -89,6 +89,6 @@ def _save_notification(self, uaid_data, notification): "Location": location}, logged_status=204) return deferToThread( - self.ap_settings.message_tables[month_table].store_message, + self.db.message_tables[month_table].store_message, notification=notification, ) diff --git a/autopush/settings.py b/autopush/settings.py index 293b6927..f700cc47 100644 --- a/autopush/settings.py +++ b/autopush/settings.py @@ -1,48 +1,32 @@ """Autopush Settings Object and Setup""" -import datetime +import json import socket - +from argparse import Namespace # noqa from hashlib import sha256 +from typing import Any # noqa from cryptography.fernet import Fernet, MultiFernet from cryptography.hazmat.primitives import constant_time from twisted.internet import reactor -from twisted.internet.defer import ( - inlineCallbacks, - returnValue, -) -from twisted.internet.threads import deferToThread from twisted.web.client import Agent, HTTPConnectionPool, _HTTP11ClientFactory -from autopush.db import ( - get_router_table, - get_storage_table, - get_rotating_message_table, - preflight_check, - Storage, - Router, - Message, -) -from autopush.exceptions import InvalidTokenException, VapidAuthException -from autopush.metrics import ( - DatadogMetrics, - TwistedMetrics, - SinkMetrics, -) -from autopush.router import ( - APNSRouter, - GCMRouter, - SimpleRouter, - WebPushRouter, + +import autopush.db as db +from autopush.exceptions import ( + InvalidSettings, + InvalidTokenException, + VapidAuthException ) from autopush.utils import ( + CLIENT_SHA256_RE, canonical_url, + get_amid, resolve_ip, repad, base64url_decode, parse_auth_header, ) -from autopush.crypto_key import (CryptoKey, CryptoKeyException) +from autopush.crypto_key import CryptoKey, CryptoKeyException class QuietClientFactory(_HTTP11ClientFactory): @@ -144,18 +128,11 @@ def __init__(self, if resolve_hostname: self.hostname = resolve_ip(self.hostname) - # Metrics setup - if datadog_api_key: - self.metrics = DatadogMetrics( - hostname=self.hostname, - api_key=datadog_api_key, - app_key=datadog_app_key, - flush_interval=datadog_flush_interval, - ) - elif statsd_host: - self.metrics = TwistedMetrics(statsd_host, 
statsd_port) - else: - self.metrics = SinkMetrics() + self.datadog_api_key = datadog_api_key + self.datadog_app_key = datadog_app_key + self.datadog_flush_interval = datadog_flush_interval + self.statsd_host = statsd_host + self.statsd_port = statsd_port self.port = port self.router_port = router_port @@ -200,35 +177,17 @@ def __init__(self, self.max_connections = max_connections self.close_handshake_timeout = close_handshake_timeout - # Database objects - self.router_table = get_router_table(router_tablename, - router_read_throughput, - router_write_throughput) - self.storage_table = get_storage_table( - storage_tablename, - storage_read_throughput, - storage_write_throughput) - self.message_table = get_rotating_message_table( - message_tablename, - message_read_throughput=message_read_throughput, - message_write_throughput=message_write_throughput) - self._message_prefix = message_tablename - self.message_limit = msg_limit - self.storage = Storage(self.storage_table, self.metrics) - self.router = Router(self.router_table, self.metrics) - - # Used to determine whether a connection is out of date with current - # db objects. There are three noteworty cases: - # 1 "Last Month" the table requires a rollover. - # 2 "This Month" the most common case. - # 3 "Next Month" where the system will soon be rolling over, but with - # timing, some nodes may roll over sooner. Ensuring the next month's - # table is present before the switchover is the main reason for this, - # just in case some nodes do switch sooner. - self.create_initial_message_tables() - - # Run preflight check - preflight_check(self.storage, self.router, preflight_uaid) + self.router_tablename = router_tablename + self.router_read_throughput = router_read_throughput + self.router_write_throughput = router_write_throughput + self.storage_tablename = storage_tablename + self.storage_read_throughput = storage_read_throughput + self.storage_write_throughput = storage_write_throughput + self.message_tablename = message_tablename + self.message_read_throughput = message_read_throughput + self.message_write_throughput = message_write_throughput + + self.msg_limit = msg_limit # CORS self.cors = enable_cors @@ -236,18 +195,6 @@ def __init__(self, # Force timeout in idle seconds self.wake_timeout = wake_timeout - # Setup the routers - self.routers = dict() - self.routers["simplepush"] = SimpleRouter( - self, - router_conf.get("simplepush") - ) - self.routers["webpush"] = WebPushRouter(self, None) - if 'apns' in router_conf: - self.routers["apns"] = APNSRouter(self, router_conf["apns"]) - if 'gcm' in router_conf: - self.routers["gcm"] = GCMRouter(self, router_conf["gcm"]) - # Env self.env = env @@ -258,78 +205,123 @@ def __init__(self, # Generate messages per legacy rules, only used for testing to # generate legacy data. self._notification_legacy = False - - @property - def message(self): - """Property that access the current message table""" - return self.message_tables[self.current_msg_month] - - @message.setter - def message(self, value): - """Setter to set the current message table""" - self.message_tables[self.current_msg_month] = value - - def _tomorrow(self): - return datetime.date.today() + datetime.timedelta(days=1) - - def create_initial_message_tables(self): - """Initializes a dict of the initial rotating messages tables. - - An entry for last months table, an entry for this months table, - an entry for tomorrow, if tomorrow is a new month. 
- - """ - today = datetime.date.today() - last_month = get_rotating_message_table(self._message_prefix, -1) - this_month = get_rotating_message_table(self._message_prefix) - self.current_month = today.month - self.current_msg_month = this_month.table_name - self.message_tables = { - last_month.table_name: Message(last_month, self.metrics), - this_month.table_name: Message(this_month, self.metrics) - } - if self._tomorrow().month != today.month: - next_month = get_rotating_message_table(self._message_prefix, - delta=1) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) - - @inlineCallbacks - def update_rotating_tables(self): - """This method is intended to be tasked to run periodically off the - twisted event hub to rotate tables. - - When today is a new month from yesterday, then we swap out all the - table objects on the settings object. - - """ - today = datetime.date.today() - tomorrow = self._tomorrow() - if ((tomorrow.month != today.month) and - sorted(self.message_tables.keys())[-1] != - tomorrow.month): - next_month = yield deferToThread( - get_rotating_message_table, - self._message_prefix, 0, tomorrow - ) - self.message_tables[next_month.table_name] = Message( - next_month, self.metrics) - - if today.month == self.current_month: - # No change in month, we're fine. - returnValue(False) - - # Get tables for the new month, and verify they exist before we try to - # switch over - message_table = yield deferToThread(get_rotating_message_table, - self._message_prefix) - - # Both tables found, safe to switch-over - self.current_month = today.month - self.current_msg_month = message_table.table_name - self.message_tables[self.current_msg_month] = \ - Message(message_table, self.metrics) - returnValue(True) + self.preflight_uaid = preflight_uaid + + @classmethod + def from_argparse(cls, ns, **kwargs): + # type: (Namespace, **Any) -> AutopushSettings + """Create an instance from argparse/additional kwargs""" + router_conf = {} + if ns.key_hash: + db.key_hash = ns.key_hash + # Some routers require a websocket to timeout on idle + # (e.g. UDP) + if ns.wake_pem is not None and ns.wake_timeout != 0: + router_conf["simplepush"] = {"idle": ns.wake_timeout, + "server": ns.wake_server, + "cert": ns.wake_pem} + if ns.apns_creds: + # if you have the critical elements for each external + # router, create it + try: + router_conf["apns"] = json.loads(ns.apns_creds) + except (ValueError, TypeError): + raise InvalidSettings( + "Invalid JSON specified for APNS config options") + if ns.gcm_enabled: + # Create a common gcmclient + try: + sender_ids = json.loads(ns.senderid_list) + except (ValueError, TypeError): + raise InvalidSettings( + "Invalid JSON specified for senderid_list") + try: + # This is an init check to verify that things are + # configured correctly. Otherwise errors may creep in + # later that go unaccounted. 
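+                # An empty or malformed senderid_list dict raises here
+                # (IndexError/TypeError) and is reported as InvalidSettings.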
+ sender_ids[sender_ids.keys()[0]] + except (IndexError, TypeError): + raise InvalidSettings("No GCM SenderIDs specified or found.") + router_conf["gcm"] = {"ttl": ns.gcm_ttl, + "dryrun": ns.gcm_dryrun, + "max_data": ns.max_data, + "collapsekey": ns.gcm_collapsekey, + "senderIDs": sender_ids} + + client_certs = None + # endpoint only + if getattr(ns, 'client_certs', None): + try: + client_certs_arg = json.loads(ns.client_certs) + except (ValueError, TypeError): + raise InvalidSettings( + "Invalid JSON specified for client_certs") + if client_certs_arg: + if not ns.ssl_key: + raise InvalidSettings("client_certs specified without SSL " + "enabled (no ssl_key specified)") + client_certs = {} + for name, sigs in client_certs_arg.iteritems(): + if not isinstance(sigs, list): + raise InvalidSettings( + "Invalid JSON specified for client_certs") + for sig in sigs: + sig = sig.upper() + if (not name or not CLIENT_SHA256_RE.match(sig) or + sig in client_certs): + raise InvalidSettings( + "Invalid client_certs argument") + client_certs[sig] = name + + if ns.fcm_enabled: + # Create a common gcmclient + if not ns.fcm_auth: + raise InvalidSettings("No Authorization Key found for FCM") + if not ns.fcm_senderid: + raise InvalidSettings("No SenderID found for FCM") + router_conf["fcm"] = {"ttl": ns.fcm_ttl, + "dryrun": ns.fcm_dryrun, + "max_data": ns.max_data, + "collapsekey": ns.fcm_collapsekey, + "auth": ns.fcm_auth, + "senderid": ns.fcm_senderid} + + ami_id = None + # Not a fan of double negatives, but this makes more + # understandable args + if not ns.no_aws: + ami_id = get_amid() + + return cls( + crypto_key=ns.crypto_key, + datadog_api_key=ns.datadog_api_key, + datadog_app_key=ns.datadog_app_key, + datadog_flush_interval=ns.datadog_flush_interval, + hostname=ns.hostname, + statsd_host=ns.statsd_host, + statsd_port=ns.statsd_port, + router_conf=router_conf, + router_tablename=ns.router_tablename, + storage_tablename=ns.storage_tablename, + storage_read_throughput=ns.storage_read_throughput, + storage_write_throughput=ns.storage_write_throughput, + message_tablename=ns.message_tablename, + message_read_throughput=ns.message_read_throughput, + message_write_throughput=ns.message_write_throughput, + router_read_throughput=ns.router_read_throughput, + router_write_throughput=ns.router_write_throughput, + resolve_hostname=ns.resolve_hostname, + wake_timeout=ns.wake_timeout, + ami_id=ami_id, + client_certs=client_certs, + msg_limit=ns.msg_limit, + connect_timeout=ns.connection_timeout, + memusage_port=ns.memusage_port, + ssl_key=ns.ssl_key, + ssl_cert=ns.ssl_cert, + ssl_dh_param=ns.ssl_dh_param, + **kwargs + ) def update(self, **kwargs): """Update the arguments, if a ``crypto_key`` is in kwargs then the @@ -376,7 +368,7 @@ def make_endpoint(self, uaid, chid, key=None): ep = self.fernet.encrypt(base + sha256(raw_key).digest()).strip('=') return root + 'v2/' + ep - def parse_endpoint(self, token, version="v1", ckey_header=None, + def parse_endpoint(self, metrics, token, version="v1", ckey_header=None, auth_header=None): """Parse an endpoint into component elements of UAID, CHID and optional key hash if v2 @@ -404,7 +396,7 @@ def parse_endpoint(self, token, version="v1", ckey_header=None, vapid_auth = parse_auth_header(auth_header) if not vapid_auth: raise VapidAuthException("Invalid Auth token") - self.metrics.increment("updates.notification.auth.{}".format( + metrics.increment("updates.notification.auth.{}".format( vapid_auth['scheme'] )) # pull the public key from the VAPID auth header if needed diff 
--git a/autopush/tests/support.py b/autopush/tests/support.py index 4336155a..474ef9f9 100644 --- a/autopush/tests/support.py +++ b/autopush/tests/support.py @@ -1,6 +1,14 @@ +from mock import Mock from twisted.logger import ILogObserver from zope.interface import implementer +from autopush.db import ( + DatabaseManager, + Router, + Storage +) +from autopush.metrics import SinkMetrics + @implementer(ILogObserver) class TestingLogObserver(object): @@ -28,3 +36,12 @@ def logged_session(self): """Extract the last logged session""" return filter(lambda e: e["log_format"] == "Session", self._events)[-1] + + +def test_db(metrics=None): + """Return a test DatabaseManager: its Storage/Router are mocked""" + return DatabaseManager( + storage=Mock(spec=Storage), + router=Mock(spec=Router), + metrics=SinkMetrics() if metrics is None else metrics + ) diff --git a/autopush/tests/test_diagnostic_cli.py b/autopush/tests/test_diagnostic_cli.py index 91bad1da..9e8bfa23 100644 --- a/autopush/tests/test_diagnostic_cli.py +++ b/autopush/tests/test_diagnostic_cli.py @@ -30,7 +30,7 @@ def test_basic_load(self): "--router_tablename=fred", "http://someendpoint", ]) - eq_(cli._settings.router_table.table_name, "fred") + eq_(cli.db.router.table.table_name, "fred") def test_bad_endpoint(self): cli = self._makeFUT([ @@ -41,16 +41,19 @@ def test_bad_endpoint(self): ok_(returncode not in (None, 0)) @patch("autopush.diagnostic_cli.AutopushSettings") - def test_successfull_lookup(self, mock_settings_class): + @patch("autopush.diagnostic_cli.DatabaseManager.from_settings") + def test_successfull_lookup(self, mock_db_cstr, mock_settings_class): from autopush.diagnostic_cli import run_endpoint_diagnostic_cli mock_settings_class.return_value = mock_settings = Mock() mock_settings.parse_endpoint.return_value = dict( uaid="asdf", chid="asdf") - mock_settings.router.get_uaid.return_value = mock_item = FakeDict() + + mock_db_cstr.return_value = mock_db = Mock() + mock_db.router.get_uaid.return_value = mock_item = FakeDict() mock_item._data = {} mock_item["current_month"] = "201608120002" mock_message_table = Mock() - mock_settings.message_tables = {"201608120002": mock_message_table} + mock_db.message_tables = {"201608120002": mock_message_table} run_endpoint_diagnostic_cli([ "--router_tablename=fred", diff --git a/autopush/tests/test_endpoint.py b/autopush/tests/test_endpoint.py index 5d8fced0..3d48bf32 100644 --- a/autopush/tests/test_endpoint.py +++ b/autopush/tests/test_endpoint.py @@ -3,7 +3,6 @@ import twisted.internet.base from cryptography.fernet import Fernet, InvalidToken -from cyclone.web import Application from mock import Mock, patch from nose.tools import eq_, ok_ from twisted.internet.defer import inlineCallbacks @@ -13,8 +12,6 @@ import autopush.utils as utils from autopush.db import ( ProvisionedThroughputExceededException, - Router, - Storage, Message, ItemNotFound, create_rotating_message_table, @@ -22,10 +19,13 @@ ) from autopush.exceptions import RouterException from autopush.http import EndpointHTTPFactory -from autopush.settings import AutopushSettings +from autopush.metrics import SinkMetrics +from autopush.router import routers_from_settings from autopush.router.interface import IRouter +from autopush.settings import AutopushSettings from autopush.tests.client import Client from autopush.tests.test_db import make_webpush_notification +from autopush.tests.support import test_db from autopush.utils import ( generate_hash, ) @@ -64,12 +64,11 @@ def setUp(self): statsd_host=None, 
crypto_key='AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=', ) + db = test_db() + self.message_mock = db.message = Mock(spec=Message) self.fernet_mock = settings.fernet = Mock(spec=Fernet) - self.router_mock = settings.router = Mock(spec=Router) - self.storage_mock = settings.storage = Mock(spec=Storage) - self.message_mock = settings.message = Mock(spec=Message) - app = EndpointHTTPFactory.for_handler(MessageHandler, settings) + app = EndpointHTTPFactory.for_handler(MessageHandler, settings, db=db) self.client = Client(app) def url(self, **kwargs): @@ -153,22 +152,22 @@ def setUp(self): bear_hash_key='AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAB=', ) self.fernet_mock = settings.fernet = Mock(spec=Fernet) - self.router_mock = settings.router = Mock(spec=Router) - self.storage_mock = settings.storage = Mock(spec=Storage) - self.router_mock.register_user = Mock() - self.router_mock.register_user.return_value = (True, {}, {}) - settings.routers["test"] = Mock(spec=IRouter) - settings.router.get_uaid.return_value = { + + self.db = db = test_db() + db.router.register_user.return_value = (True, {}, {}) + db.router.get_uaid.return_value = { "router_type": "test", "router_data": dict() } - app = EndpointHTTPFactory(settings) + db.create_initial_message_tables() + + self.routers = routers = routers_from_settings(settings, db) + routers["test"] = Mock(spec=IRouter) + app = EndpointHTTPFactory(settings, db=db, routers=routers) self.client = Client(app) self.request_mock = Mock(body=b'', arguments={}, headers={}) - self.reg = NewRegistrationHandler(Application(), - self.request_mock, - ap_settings=settings) + self.reg = NewRegistrationHandler(app, self.request_mock) self.auth = ("WebPush %s" % generate_hash(settings.bear_hash_key[0], dummy_uaid.hex)) @@ -293,9 +292,12 @@ def test_post_gcm(self): from autopush.router.gcm import GCMRouter sids = {"182931248179192": {"auth": "aailsjfilajdflijdsilfjsliaj"}} - gcm = GCMRouter(self.settings, - {"dryrun": True, "senderIDs": sids}) - self.settings.routers["gcm"] = gcm + gcm = GCMRouter( + self.settings, + {"dryrun": True, "senderIDs": sids}, + SinkMetrics() + ) + self.routers["gcm"] = gcm self.fernet_mock.configure_mock(**{ 'encrypt.return_value': 'abcd123', }) @@ -314,7 +316,7 @@ def test_post_gcm(self): eq_(payload["uaid"], dummy_uaid.hex) eq_(payload["channelID"], dummy_chid.hex) eq_(payload["endpoint"], "http://localhost/wpush/v1/abcd123") - calls = self.settings.router.register_user.call_args + calls = self.db.router.register_user.call_args call_args = calls[0][0] eq_(True, has_connected_this_month(call_args)) ok_("secret" in payload) @@ -348,10 +350,9 @@ def test_post_bad_router_type(self): @inlineCallbacks def test_post_bad_router_register(self, *args): - frouter = Mock(spec=IRouter) - self.settings.routers["simplepush"] = frouter + router = self.routers["simplepush"] rexc = RouterException("invalid", status_code=402, errno=107) - frouter.register = Mock(side_effect=rexc) + router.register = Mock(side_effect=rexc) resp = yield self.client.post( self.url(router_type="simplepush"), @@ -386,8 +387,7 @@ def test_post_existing_uaid(self): @inlineCallbacks def test_no_uaid(self): - self.settings.router.get_uaid = Mock() - self.settings.router.get_uaid.side_effect = ItemNotFound + self.db.router.get_uaid.side_effect = ItemNotFound resp = yield self.client.delete( self.url(router_type="webpush", uaid=dummy_uaid.hex, @@ -487,7 +487,7 @@ def test_put(self, *args): self.patch('uuid.uuid4', return_value=dummy_uaid) data = dict(token="some_token") - frouter = 
self.settings.routers["test"] + frouter = self.routers["test"] frouter.register = Mock() frouter.register.return_value = data @@ -504,7 +504,7 @@ def test_put(self, *args): router_data=data, app_id='test', ) - user_data = self.router_mock.register_user.call_args[0][0] + user_data = self.db.router.register_user.call_args[0][0] eq_(user_data['uaid'], dummy_uaid.hex) eq_(user_data['router_type'], 'test') eq_(user_data['router_data']['token'], 'some_token') @@ -547,7 +547,7 @@ def test_put_bad_arguments(self, *args): @inlineCallbacks def test_put_bad_router_register(self): - frouter = self.settings.routers["test"] + frouter = self.routers["test"] rexc = RouterException("invalid", status_code=402, errno=107) frouter.register = Mock(side_effect=rexc) @@ -561,7 +561,7 @@ def test_put_bad_router_register(self): @inlineCallbacks def test_delete_bad_chid_value(self): notif = make_webpush_notification(dummy_uaid.hex, str(dummy_chid)) - messages = self.settings.message + messages = self.db.message messages.register_channel(dummy_uaid.hex, str(dummy_chid)) messages.store_message(notif) @@ -577,7 +577,7 @@ def test_delete_bad_chid_value(self): @inlineCallbacks def test_delete_no_such_chid(self): notif = make_webpush_notification(dummy_uaid.hex, str(dummy_chid)) - messages = self.settings.message + messages = self.db.message messages.register_channel(dummy_uaid.hex, str(dummy_chid)) messages.store_message(notif) @@ -599,11 +599,10 @@ def test_delete_no_such_chid(self): def test_delete_uaid(self): notif = make_webpush_notification(dummy_uaid.hex, str(dummy_chid)) notif2 = make_webpush_notification(dummy_uaid.hex, str(dummy_chid)) - messages = self.settings.message + messages = self.db.message messages.store_message(notif) messages.store_message(notif2) - self.settings.router.drop_user = Mock() - self.settings.router.drop_user.return_value = True + self.db.router.drop_user.return_value = True yield self.client.delete( self.url(router_type="simplepush", @@ -613,9 +612,8 @@ def test_delete_uaid(self): ) # Note: Router is mocked, so the UAID is never actually # dropped. 
- ok_(self.settings.router.drop_user.called) - eq_(self.settings.router.drop_user.call_args_list[0][0], - (dummy_uaid.hex,)) + ok_(self.db.router.drop_user.called) + eq_(self.db.router.drop_user.call_args_list[0][0], (dummy_uaid.hex,)) @inlineCallbacks def test_delete_bad_uaid(self): @@ -629,8 +627,7 @@ def test_delete_bad_uaid(self): @inlineCallbacks def test_delete_orphans(self): - self.router_mock.drop_user = Mock() - self.router_mock.drop_user.return_value = False + self.db.router.drop_user.return_value = False resp = yield self.client.delete( self.url(router_type="test", router_token="test", @@ -663,15 +660,15 @@ def test_delete_bad_router(self): def test_get(self): chids = [str(dummy_chid), str(dummy_uaid)] - self.settings.message.all_channels = Mock() - self.settings.message.all_channels.return_value = (True, chids) + self.db.message.all_channels = Mock() + self.db.message.all_channels.return_value = (True, chids) resp = yield self.client.get( self.url(router_type="test", router_token="test", uaid=dummy_uaid.hex), headers={"Authorization": self.auth} ) - self.settings.message.all_channels.assert_called_with(str(dummy_uaid)) + self.db.message.all_channels.assert_called_with(str(dummy_uaid)) payload = json.loads(resp.content) eq_(chids, payload['channelIDs']) eq_(dummy_uaid.hex, payload['uaid']) diff --git a/autopush/tests/test_health.py b/autopush/tests/test_health.py index 3cb6c3d8..f777e3e6 100644 --- a/autopush/tests/test_health.py +++ b/autopush/tests/test_health.py @@ -2,7 +2,6 @@ import twisted.internet.base from boto.dynamodb2.exceptions import InternalServerError -from cyclone.web import Application from mock import Mock from moto import mock_dynamodb2 from nose.tools import eq_ @@ -33,8 +32,6 @@ def setUp(self): hostname="localhost", statsd_host=None, ) - self.router_table = settings.router.table - self.storage_table = settings.storage.table # ignore logging logs = TestingLogObserver() @@ -42,6 +39,8 @@ def setUp(self): self.addCleanup(globalLogPublisher.removeObserver, logs) app = EndpointHTTPFactory.for_handler(HealthHandler, settings) + self.router_table = app.db.router.table + self.storage_table = app.db.storage.table self.client = Client(app) @inlineCallbacks @@ -138,8 +137,10 @@ def setUp(self): statsd_host=None, ) self.request_mock = Mock() - self.status = StatusHandler(Application(), self.request_mock, - ap_settings=settings) + self.status = StatusHandler( + EndpointHTTPFactory(settings, db=None, routers=None), + self.request_mock + ) self.write_mock = self.status.write = Mock() def test_status(self): diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index f9c3c5a6..267cd270 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -13,6 +13,7 @@ from distutils.spawn import find_executable from StringIO import StringIO from httplib import HTTPResponse # noqa +from mock import Mock, call from unittest.case import SkipTest from zope.interface import implementer @@ -612,8 +613,7 @@ def test_basic_last_connect(self): yield client.disconnect() # Verify the last_connect is there and the current month - c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) eq_(True, has_connected_this_month(c)) # Move it back @@ -629,8 +629,7 @@ def test_basic_last_connect(self): yield client.disconnect() times = 0 while times < 10: - c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) + c = yield 
deferToThread(self.conn.db.router.get_uaid, client.uaid) if has_connected_this_month(c): break else: # pragma: nocover @@ -1257,12 +1256,11 @@ def test_message_without_crypto_headers(self): @inlineCallbacks def test_message_with_topic(self): - from mock import Mock, call data = str(uuid.uuid4()) - self.conn.settings.metrics = Mock(spec=SinkMetrics) + self.conn.db.metrics = Mock(spec=SinkMetrics) client = yield self.quick_register(use_webpush=True) yield client.send_notification(data=data, topic="topicname") - self.conn.settings.metrics.increment.assert_has_calls([ + self.conn.db.metrics.increment.assert_has_calls([ call('updates.notification.topic', tags=['host:localhost', 'use_webpush:True']) ]) @@ -1343,17 +1341,16 @@ def test_webpush_monthly_rotation(self): # Move the client back one month to the past last_month = make_rotating_tablename( - prefix=self.conn.settings._message_prefix, delta=-1) - lm_message = self.conn.settings.message_tables[last_month] + prefix=self.conn.db._message_prefix, delta=-1) + lm_message = self.conn.db.message_tables[last_month] yield deferToThread( - self.conn.settings.router.update_message_month, + self.conn.db.router.update_message_month, client.uaid, last_month ) # Verify the move - c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) eq_(c["current_month"], last_month) # Verify last_connect is current, then move that back @@ -1366,7 +1363,7 @@ def test_webpush_monthly_rotation(self): # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.settings.message.all_channels, client.uaid + self.conn.db.message.all_channels, client.uaid ) eq_(exists, True) eq_(len(chans), 1) @@ -1377,14 +1374,15 @@ def test_webpush_monthly_rotation(self): ) # Remove the channels entry entirely from this month - yield deferToThread(self.conn.settings.message.table.delete_item, - uaid=client.uaid, - chidmessageid=" " - ) + yield deferToThread( + self.conn.db.message.table.delete_item, + uaid=client.uaid, + chidmessageid=" " + ) # Verify the channel is gone exists, chans = yield deferToThread( - self.conn.settings.message.all_channels, + self.conn.db.message.all_channels, client.uaid ) eq_(exists, False) @@ -1421,16 +1419,16 @@ def test_webpush_monthly_rotation(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) - if c["current_month"] == self.conn.settings.current_msg_month: + self.conn.db.router.get_uaid, client.uaid) + if c["current_month"] == self.conn.db.current_msg_month: break else: yield deferToThread(time.sleep, 0.2) # Verify the month update in the router table c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) - eq_(c["current_month"], self.conn.settings.current_msg_month) + self.conn.db.router.get_uaid, client.uaid) + eq_(c["current_month"], self.conn.db.current_msg_month) eq_(server_client.ps.rotate_message_table, False) # Verify the client moved last_connect @@ -1438,7 +1436,7 @@ def test_webpush_monthly_rotation(self): # Verify the channels were moved exists, chans = yield deferToThread( - self.conn.settings.message.all_channels, + self.conn.db.message.all_channels, client.uaid ) eq_(exists, True) @@ -1454,17 +1452,16 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Move the client back one month to the past last_month = make_rotating_tablename( - prefix=self.conn.settings._message_prefix, delta=-1) - lm_message = 
self.conn.settings.message_tables[last_month] + prefix=self.conn.db._message_prefix, delta=-1) + lm_message = self.conn.db.message_tables[last_month] yield deferToThread( - self.conn.settings.router.update_message_month, + self.conn.db.router.update_message_month, client.uaid, last_month ) # Verify the move - c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) eq_(c["current_month"], last_month) # Verify last_connect is current, then move that back @@ -1477,7 +1474,7 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Move the clients channels back one month exists, chans = yield deferToThread( - self.conn.settings.message.all_channels, client.uaid + self.conn.db.message.all_channels, client.uaid ) eq_(exists, True) eq_(len(chans), 1) @@ -1518,16 +1515,15 @@ def test_webpush_monthly_rotation_prior_record_exists(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) - if c["current_month"] == self.conn.settings.current_msg_month: + self.conn.db.router.get_uaid, client.uaid) + if c["current_month"] == self.conn.db.current_msg_month: break else: yield deferToThread(time.sleep, 0.2) # Verify the month update in the router table - c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) - eq_(c["current_month"], self.conn.settings.current_msg_month) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + eq_(c["current_month"], self.conn.db.current_msg_month) eq_(server_client.ps.rotate_message_table, False) # Verify the client moved last_connect @@ -1535,7 +1531,7 @@ def test_webpush_monthly_rotation_prior_record_exists(self): # Verify the channels were moved exists, chans = yield deferToThread( - self.conn.settings.message.all_channels, + self.conn.db.message.all_channels, client.uaid ) eq_(exists, True) @@ -1553,21 +1549,20 @@ def test_webpush_monthly_rotation_no_channels(self): # Move the client back one month to the past last_month = make_rotating_tablename( - prefix=self.conn.settings._message_prefix, delta=-1) + prefix=self.conn.db._message_prefix, delta=-1) yield deferToThread( - self.conn.settings.router.update_message_month, + self.conn.db.router.update_message_month, client.uaid, last_month ) # Verify the move - c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) eq_(c["current_month"], last_month) # Verify there's no channels exists, chans = yield deferToThread( - self.conn.settings.message.all_channels, + self.conn.db.message.all_channels, client.uaid ) eq_(exists, False) @@ -1585,16 +1580,15 @@ def test_webpush_monthly_rotation_no_channels(self): start = time.time() while time.time()-start < 2: c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) - if c["current_month"] == self.conn.settings.current_msg_month: + self.conn.db.router.get_uaid, client.uaid) + if c["current_month"] == self.conn.db.current_msg_month: break else: yield deferToThread(time.sleep, 0.2) # Verify the month update in the router table - c = yield deferToThread( - self.conn.settings.router.get_uaid, client.uaid) - eq_(c["current_month"], self.conn.settings.current_msg_month) + c = yield deferToThread(self.conn.db.router.get_uaid, client.uaid) + eq_(c["current_month"], self.conn.db.current_msg_month) eq_(server_client.ps.rotate_message_table, False) yield self.shut_down(client) @@ 
-1794,7 +1788,6 @@ def needs_retry(cls=None): def _add_router(self): from autopush.router.gcm import GCMRouter - from mock import Mock gcm = GCMRouter( self.ep.settings, { @@ -1805,9 +1798,10 @@ def _add_router(self): "senderIDs": {self.senderID: {"auth": "AIzaSyCx9PRtH8ByaJR3Cf" "Jamz0D2N0uaCgRGiI"}} - } + }, + self.ep.db.metrics ) - self.ep.settings.routers["gcm"] = gcm + self.ep.routers["gcm"] = gcm # Set up the mock call to avoid calling the live system. # The problem with calling the live system (even sandboxed) is that # you need a valid credential set from a mobile device, which can be @@ -1983,7 +1977,6 @@ class TestFCMBridgeIntegration(IntegrationBase): def _add_router(self): from autopush.router.fcm import FCMRouter - from mock import Mock fcm = FCMRouter( self.ep.settings, { @@ -1993,9 +1986,10 @@ def _add_router(self): "collapsekey": "test", "senderID": self.senderID, "auth": "AIzaSyCx9PRtH8ByaJR3CfJamz0D2N0uaCgRGiI", - } + }, + self.ep.db.metrics ) - self.ep.settings.routers["fcm"] = fcm + self.ep.routers["fcm"] = fcm # Set up the mock call to avoid calling the live system. # The problem with calling the live system (even sandboxed) is that # you need a valid credential set from a mobile device, which can be @@ -2066,17 +2060,19 @@ class m_response: def _add_router(self): from autopush.router.apnsrouter import APNSRouter - from mock import Mock apns = APNSRouter( - self.ep.settings, { + self.ep.settings, + { "firefox": { "cert": "/home/user/certs/SimplePushDemo.p12_cert.pem", "key": "/home/user/certs/SimplePushDemo.p12_key.pem", "sandbox": True, } }, - load_connections=False,) - self.ep.settings.routers["apns"] = apns + self.ep.db.metrics, + load_connections=False + ) + self.ep.routers["apns"] = apns # Set up the mock call to avoid calling the live system. 
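The three bridge-router fixtures in this area (GCM, FCM, APNS) all change the same way: the router constructor now takes the metrics object as an explicit argument, and the instance is registered on the application's `routers` dict instead of on `AutopushSettings`. A minimal sketch of that wiring, assuming a hypothetical `ep` application fixture and placeholder config values (the real config carries additional options):

```python
# Sketch of the new bridge-router wiring; `ep`, `add_gcm_router` and the
# config values are illustrative placeholders, not part of the diff.
from autopush.router.gcm import GCMRouter


def add_gcm_router(ep, sender_id):
    gcm_config = {
        "senderIDs": {sender_id: {"auth": "placeholder-server-key"}},
    }
    # Metrics are passed explicitly instead of being read off settings,
    # and the router is attached to the application, not to settings.
    gcm = GCMRouter(ep.settings, gcm_config, ep.db.metrics)
    ep.routers["gcm"] = gcm
    return gcm
```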
# The problem with calling the live system (even sandboxed) is that # you need a valid credential set from a mobile device, which can be diff --git a/autopush/tests/test_main.py b/autopush/tests/test_main.py index 5788c2a0..bc7d91c4 100644 --- a/autopush/tests/test_main.py +++ b/autopush/tests/test_main.py @@ -14,18 +14,17 @@ import hyper import hyper.tls -from autopush.db import get_rotating_message_table +from autopush.db import DatabaseManager, get_rotating_message_table from autopush.exceptions import InvalidSettings from autopush.http import skip_request_logging from autopush.main import ( ConnectionApplication, EndpointApplication, - make_settings, ) from autopush.settings import AutopushSettings +from autopush.tests.support import test_db from autopush.utils import resolve_ip - connection_main = ConnectionApplication.main endpoint_main = EndpointApplication.main mock_dynamodb2 = mock_dynamodb2() @@ -62,10 +61,11 @@ def test_new_month(self): tomorrow = datetime.datetime(year=next_year, month=next_month, day=1) - AutopushSettings._tomorrow = Mock() - AutopushSettings._tomorrow.return_value = tomorrow - settings = AutopushSettings() - eq_(len(settings.message_tables), 3) + db = test_db() + db._tomorrow = Mock() + db._tomorrow.return_value = tomorrow + db.create_initial_message_tables() + eq_(len(db.message_tables), 3) class SettingsAsyncTestCase(trialtest.TestCase): @@ -73,25 +73,27 @@ def test_update_rotating_tables(self): from autopush.db import get_month settings = AutopushSettings( hostname="example.com", resolve_hostname=True) + db = DatabaseManager.from_settings(settings) + db.create_initial_message_tables() # Erase the tables it has on init, and move current month back one last_month = get_month(-1) - settings.current_month = last_month.month - settings.message_tables = {} + db.current_month = last_month.month + db.message_tables = {} # Create the next month's table, just in case today is the day before # a new month, in which case the lack of keys will cause an error in # update_rotating_tables next_month = get_month(1) - settings.message_tables[next_month.month] = None + db.message_tables[next_month.month] = None # Get the deferred back e = Deferred() - d = settings.update_rotating_tables() + d = db.update_rotating_tables() def check_tables(result): - eq_(len(settings.message_tables), 2) - eq_(settings.current_month, get_month().month) + eq_(len(db.message_tables), 2) + eq_(db.current_month, get_month().month) d.addCallback(check_tables) d.addBoth(lambda x: e.callback(True)) @@ -123,26 +125,26 @@ def test_update_rotating_tables_month_end(self): tomorrow = datetime.datetime(year=next_year, month=next_month, day=1) - AutopushSettings._tomorrow = Mock() - AutopushSettings._tomorrow.return_value = tomorrow settings = AutopushSettings( hostname="example.com", resolve_hostname=True) + db = DatabaseManager.from_settings(settings) + db._tomorrow = Mock(return_value=tomorrow) + db.create_initial_message_tables() # We should have 3 tables, one for next/this/last month - eq_(len(settings.message_tables), 3) + eq_(len(db.message_tables), 3) # Grab next month's table name and remove it - next_month = get_rotating_message_table(settings._message_prefix, - delta=1) - settings.message_tables.pop(next_month.table_name) + next_month = get_rotating_message_table(db._message_prefix, delta=1) + db.message_tables.pop(next_month.table_name) # Get the deferred back - d = settings.update_rotating_tables() + d = db.update_rotating_tables() def check_tables(result): - eq_(len(settings.message_tables), 
3) - ok_(next_month.table_name in settings.message_tables) + eq_(len(db.message_tables), 3) + ok_(next_month.table_name in db.message_tables) d.addCallback(check_tables) return d @@ -151,22 +153,24 @@ def test_update_not_needed(self): from autopush.db import get_month settings = AutopushSettings( hostname="google.com", resolve_hostname=True) + db = DatabaseManager.from_settings(settings) + db.create_initial_message_tables() # Erase the tables it has on init, and move current month back one - settings.message_tables = {} + db.message_tables = {} # Create the next month's table, just in case today is the day before # a new month, in which case the lack of keys will cause an error in # update_rotating_tables next_month = get_month(1) - settings.message_tables[next_month.month] = None + db.message_tables[next_month.month] = None # Get the deferred back e = Deferred() - d = settings.update_rotating_tables() + d = db.update_rotating_tables() def check_tables(result): - eq_(len(settings.message_tables), 1) + eq_(len(db.message_tables), 1) d.addCallback(check_tables) d.addBoth(lambda x: e.callback(True)) @@ -178,7 +182,7 @@ def setUp(self): patchers = [ "autopush.main.TimerService.startService", "autopush.main.reactor", - "autopush.settings.TwistedMetrics", + "autopush.metrics.TwistedMetrics", ] self.mocks = {} for name in patchers: @@ -270,10 +274,10 @@ class TestArg: def setUp(self): patchers = [ + "autopush.db.preflight_check", "autopush.main.TimerService.startService", "autopush.main.reactor", - "autopush.settings.TwistedMetrics", - "autopush.settings.preflight_check", + "autopush.metrics.TwistedMetrics", ] self.mocks = {} for name in patchers: @@ -329,10 +333,10 @@ def test_memusage(self): @patch('hyper.tls', spec=hyper.tls) def test_client_certs_parse(self, mock): - ap = make_settings(self.TestArg) - eq_(ap.client_certs["1A:"*31 + "F9"], 'partner1') - eq_(ap.client_certs["2B:"*31 + "E8"], 'partner2') - eq_(ap.client_certs["3C:"*31 + "D7"], 'partner2') + settings = AutopushSettings.from_argparse(self.TestArg) + eq_(settings.client_certs["1A:"*31 + "F9"], 'partner1') + eq_(settings.client_certs["2B:"*31 + "E8"], 'partner2') + eq_(settings.client_certs["3C:"*31 + "D7"], 'partner2') def test_bad_client_certs(self): cert = self.TestArg._client_certs['partner1'][0] @@ -355,19 +359,20 @@ def test_bad_client_certs(self): spec=hyper.HTTP20Connection) @patch('hyper.tls', spec=hyper.tls) def test_settings(self, *args): - ap = make_settings(self.TestArg) + settings = AutopushSettings.from_argparse(self.TestArg) + app = EndpointApplication(settings) # verify that the hostname is what we said. 
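The table-rotation tests above and below share one pattern: build a `DatabaseManager` from settings, seed the initial message tables (optionally pinning `_tomorrow` to force the month-end path), then drive `update_rotating_tables()`. A condensed sketch of that pattern, assuming the moto-mocked DynamoDB environment these tests run under and an illustrative pinned date:

```python
# Condensed sketch of the rotation-test pattern; the pinned "tomorrow" is
# illustrative and assumes the mock_dynamodb2 environment used above.
import datetime

from mock import Mock

from autopush.db import DatabaseManager
from autopush.settings import AutopushSettings


def build_rotated_db():
    settings = AutopushSettings(hostname="example.com",
                                resolve_hostname=True)
    db = DatabaseManager.from_settings(settings)

    # Pin tomorrow to the first of a month so initial setup also creates
    # next month's table (last / this / next = 3 entries).
    db._tomorrow = Mock(return_value=datetime.datetime(2016, 3, 1))
    db.create_initial_message_tables()
    assert len(db.message_tables) == 3

    # Rotation runs off the reactor and returns a Deferred.
    return db.update_rotating_tables()
```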
- eq_(ap.hostname, self.TestArg.hostname) - eq_(ap.routers["gcm"].config['collapsekey'], "collapse") - eq_(ap.routers["apns"]._config['firefox']['cert'], "cert.file") - eq_(ap.routers["apns"]._config['firefox']['key'], "key.file") - eq_(ap.wake_timeout, 10) + eq_(settings.hostname, self.TestArg.hostname) + eq_(app.routers["gcm"].config['collapsekey'], "collapse") + eq_(app.routers["apns"]._config['firefox']['cert'], "cert.file") + eq_(app.routers["apns"]._config['firefox']['key'], "key.file") + eq_(settings.wake_timeout, 10) def test_bad_senders(self): old_list = self.TestArg.senderid_list self.TestArg.senderid_list = "{}" with assert_raises(InvalidSettings): - make_settings(self.TestArg) + AutopushSettings.from_argparse(self.TestArg) self.TestArg.senderid_list = old_list def test_bad_fcm_senders(self): @@ -375,11 +380,11 @@ def test_bad_fcm_senders(self): old_senderid = self.TestArg.fcm_senderid self.TestArg.fcm_auth = "" with assert_raises(InvalidSettings): - make_settings(self.TestArg) + AutopushSettings.from_argparse(self.TestArg) self.TestArg.fcm_auth = old_auth self.TestArg.fcm_senderid = "" with assert_raises(InvalidSettings): - make_settings(self.TestArg) + AutopushSettings.from_argparse(self.TestArg) self.TestArg.fcm_senderid = old_senderid def test_gcm_start(self): @@ -398,5 +403,5 @@ class MockReply: request_mock.return_value = MockReply self.TestArg.no_aws = False - ap = make_settings(self.TestArg) - eq_(ap.ami_id, "ami_123") + settings = AutopushSettings.from_argparse(self.TestArg) + eq_(settings.ami_id, "ami_123") diff --git a/autopush/tests/test_router.py b/autopush/tests/test_router.py index c2a4240e..dc32d77d 100644 --- a/autopush/tests/test_router.py +++ b/autopush/tests/test_router.py @@ -20,14 +20,13 @@ from hyper.http20.exceptions import HTTP20Error from autopush.db import ( - Router, - Storage, Message, ProvisionedThroughputExceededException, ItemNotFound, create_rotating_message_table, ) from autopush.exceptions import RouterException +from autopush.metrics import SinkMetrics from autopush.router import ( APNSRouter, GCMRouter, @@ -38,6 +37,7 @@ from autopush.router.interface import RouterResponse, IRouter from autopush.settings import AutopushSettings from autopush.tests import MockAssist +from autopush.tests.support import test_db from autopush.web.base import Notification @@ -101,7 +101,7 @@ def setUp(self, mt, mc): } self.mock_connection = mc mc.return_value = mc - self.router = APNSRouter(settings, apns_config) + self.router = APNSRouter(settings, apns_config, SinkMetrics()) self.mock_response = Mock() self.mock_response.status = 200 mc.get_response.return_value = self.mock_response @@ -317,7 +317,7 @@ def setUp(self, fgcm): 'senderIDs': {'test123': {"auth": "12345678abcdefg"}}} self.gcm = fgcm - self.router = GCMRouter(settings, self.gcm_config) + self.router = GCMRouter(settings, self.gcm_config, SinkMetrics()) self.headers = {"content-encoding": "aesgcm", "encryption": "test", "encryption-key": "test"} @@ -357,7 +357,7 @@ def test_init(self): statsd_host=None, ) with assert_raises(IOError): - GCMRouter(settings, {"senderIDs": {}}) + GCMRouter(settings, {"senderIDs": {}}, SinkMetrics()) def test_register(self): router_data = {"token": "test123"} @@ -386,7 +386,11 @@ def test_gcmclient_fail(self, fgcm): statsd_host=None, ) with assert_raises(IOError): - GCMRouter(settings, {"senderIDs": {"test123": {"auth": "abcd"}}}) + GCMRouter( + settings, + {"senderIDs": {"test123": {"auth": "abcd"}}}, + SinkMetrics() + ) def test_route_notification(self): 
self.router.gcm['test123'] = self.gcm @@ -627,7 +631,7 @@ def setUp(self, ffcm): 'senderID': 'test123', "auth": "12345678abcdefg"} self.fcm = ffcm - self.router = FCMRouter(settings, self.fcm_config) + self.router = FCMRouter(settings, self.fcm_config, SinkMetrics()) self.headers = {"content-encoding": "aesgcm", "encryption": "test", "encryption-key": "test"} @@ -672,7 +676,7 @@ def throw_auth(*args, **kwargs): ffcm.side_effect = throw_auth with assert_raises(IOError): - FCMRouter(settings, {}) + FCMRouter(settings, {}, SinkMetrics()) def test_register(self): router_data = {"token": "test123"} @@ -906,8 +910,10 @@ def setUp(self): hostname="localhost", statsd_host=None, ) + self.metrics = metrics = Mock(spec=SinkMetrics) + db = test_db(metrics=metrics) - self.router = SimpleRouter(settings, {}) + self.router = SimpleRouter(settings, {}, db) self.router.log = Mock(spec=Logger) self.notif = Notification(10, "data", dummy_chid) mock_result = Mock(spec=gcmclient.gcm.Result) @@ -915,11 +921,10 @@ def setUp(self): mock_result.failed = dict() mock_result.not_registered = dict() mock_result.needs_retry.return_value = False - self.router_mock = settings.router = Mock(spec=Router) - self.storage_mock = settings.storage = Mock(spec=Storage) + self.router_mock = db.router + self.storage_mock = db.storage self.agent_mock = Mock(spec=settings.agent) settings.agent = self.agent_mock - self.router.metrics = Mock() def _raise_connect_error(self): raise ConnectError() @@ -1108,7 +1113,7 @@ def test_route_to_busy_node_saves_looks_up_and_sends_check_200(self): def verify_deliver(result): ok_(isinstance(result, RouterResponse)) eq_(result.status_code, 200) - self.router.metrics.increment.assert_called_with( + self.metrics.increment.assert_called_with( "router.broadcast.save_hit" ) d.addBoth(verify_deliver) @@ -1147,13 +1152,15 @@ def setUp(self): hostname="localhost", statsd_host=None, ) + self.metrics = metrics = Mock(spec=SinkMetrics) + self.db = db = test_db(metrics=metrics) self.headers = headers = { "content-encoding": "aes128", "encryption": "awesomecrypto", "crypto-key": "niftykey" } - self.router = WebPushRouter(settings, {}) + self.router = WebPushRouter(settings, {}, db) self.notif = WebPushNotification( uaid=uuid.UUID(dummy_uaid), channel_id=uuid.UUID(dummy_chid), @@ -1168,11 +1175,10 @@ def setUp(self): mock_result.failed = dict() mock_result.not_registered = dict() mock_result.needs_retry.return_value = False - self.router_mock = settings.router = Mock(spec=Router) - self.message_mock = settings.message = Mock(spec=Message) + self.router_mock = db.router + self.message_mock = db.message = Mock(spec=Message) self.agent_mock = Mock(spec=settings.agent) settings.agent = self.agent_mock - self.router.metrics = Mock() self.settings = settings def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): @@ -1183,7 +1189,7 @@ def test_route_to_busy_node_saves_looks_up_and_sends_check_201(self): self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, - current_month=self.settings.current_msg_month) + current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data self.router.message_id = uuid.uuid4().hex @@ -1197,7 +1203,7 @@ def verify_deliver(result): eq_(t_h.get('encryption'), self.headers.get('encryption')) eq_(t_h.get('crypto_key'), self.headers.get('crypto-key')) eq_(t_h.get('encoding'), self.headers.get('content-encoding')) - 
self.router.metrics.increment.assert_called_with( + self.metrics.increment.assert_called_with( "router.broadcast.save_hit" ) ok_("Location" in result.headers) @@ -1222,7 +1228,7 @@ def test_route_to_busy_node_with_ttl_zero(self): self.message_mock.store_message.return_value = True self.message_mock.all_channels.return_value = (True, [dummy_chid]) router_data = dict(node_id="http://somewhere", uaid=dummy_uaid, - current_month=self.settings.current_msg_month) + current_month=self.db.current_msg_month) self.router_mock.get_uaid.return_value = router_data self.router.message_id = uuid.uuid4().hex @@ -1232,7 +1238,7 @@ def verify_deliver(fail): exc = fail.value ok_(exc, RouterResponse) eq_(exc.status_code, 201) - eq_(len(self.router.metrics.increment.mock_calls), 0) + eq_(len(self.metrics.increment.mock_calls), 0) ok_("Location" in exc.headers) d.addBoth(verify_deliver) return d diff --git a/autopush/tests/test_web_base.py b/autopush/tests/test_web_base.py index 51f42962..f9a518fd 100644 --- a/autopush/tests/test_web_base.py +++ b/autopush/tests/test_web_base.py @@ -2,7 +2,6 @@ import uuid from boto.exception import BotoServerError -from cyclone.web import Application from mock import Mock, patch from moto import mock_dynamodb2 from nose.tools import eq_, ok_ @@ -15,6 +14,7 @@ create_rotating_message_table, ProvisionedThroughputExceededException, ) +from autopush.http import EndpointHTTPFactory from autopush.exceptions import InvalidRequest from autopush.settings import AutopushSettings @@ -57,9 +57,10 @@ def setUp(self): headers={"ttl": "0"}, host='example.com:8080') - self.base = BaseWebHandler(Application(), - self.request_mock, - ap_settings=settings) + self.base = BaseWebHandler( + EndpointHTTPFactory(settings, db=None, routers=None), + self.request_mock + ) self.status_mock = self.base.set_status = Mock() self.write_mock = self.base.write = Mock() self.base.log = Mock(spec=Logger) diff --git a/autopush/tests/test_web_validation.py b/autopush/tests/test_web_validation.py index c1fb0b1a..f899d7d3 100644 --- a/autopush/tests/test_web_validation.py +++ b/autopush/tests/test_web_validation.py @@ -18,7 +18,9 @@ from twisted.trial import unittest from autopush.db import create_rotating_message_table +from autopush.metrics import SinkMetrics from autopush.exceptions import InvalidRequest, InvalidTokenException +from autopush.tests.support import test_db import autopush.utils as utils @@ -74,6 +76,9 @@ def _write_validation_err(rh, errors): vr = ValidateRequest(app, request) vr._timings = dict() vr.ap_settings = Mock() + vr.metrics = Mock() + vr.db = Mock() + vr.routers = Mock() return vr def _make_full(self, schema=None): @@ -114,7 +119,7 @@ def test_call_func_error(self): @inlineCallbacks def test_decorator(self): - from cyclone.web import Application + from autopush.http import EndpointHTTPFactory from autopush.web.base import BaseWebHandler, threaded_validate from autopush.tests.client import Client schema = self._make_basic_schema() @@ -128,7 +133,12 @@ def get(self): self.write("done") self.finish() - app = Application([('/test', AHandler, dict(ap_settings=Mock()))]) + app = EndpointHTTPFactory( + Mock(), + db=test_db(), + routers=None, + handlers=[('/test', AHandler)] + ) client = Client(app) resp = yield client.get('/test') eq_(resp.content, "done") @@ -138,8 +148,13 @@ class TestSimplePushRequestSchema(unittest.TestCase): def _make_fut(self): from autopush.web.simplepush import SimplePushRequestSchema schema = SimplePushRequestSchema() - schema.context["settings"] = Mock() - 
schema.context["log"] = Mock() + schema.context.update( + settings=Mock(), + metrics=SinkMetrics(), + db=test_db(), + routers=Mock(), + log=Mock() + ) return schema def _make_test_data(self, headers=None, body="", path_args=None, @@ -159,7 +174,7 @@ def test_valid_data(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="simplepush", ) result, errors = schema.load(self._make_test_data()) @@ -174,7 +189,7 @@ def test_valid_data_in_body(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="simplepush", ) result, errors = schema.load( @@ -191,7 +206,7 @@ def test_valid_version(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="simplepush", ) result, errors = schema.load( @@ -209,7 +224,7 @@ def test_invalid_router_type(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="webpush", ) @@ -229,7 +244,7 @@ def test_invalid_uaid_not_found(self): def throw_item(*args, **kwargs): raise ItemNotFound("Not found") - schema.context["settings"].router.get_uaid.side_effect = throw_item + schema.context["db"].router.get_uaid.side_effect = throw_item with assert_raises(InvalidRequest) as cm: schema.load(self._make_test_data()) @@ -256,7 +271,7 @@ def test_invalid_data_size(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="simplepush", ) schema.context["settings"].max_data = 1 @@ -271,8 +286,13 @@ class TestWebPushRequestSchema(unittest.TestCase): def _make_fut(self): from autopush.web.webpush import WebPushRequestSchema schema = WebPushRequestSchema() - schema.context["settings"] = Mock() - schema.context["log"] = Mock() + schema.context.update( + settings=Mock(), + metrics=SinkMetrics(), + db=test_db(), + routers=Mock(), + log=Mock() + ) return schema def _make_test_data(self, headers=None, body="", path_args=None, @@ -292,7 +312,7 @@ def test_valid_data(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -308,7 +328,7 @@ def test_no_headers(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -328,7 +348,7 @@ def test_invalid_simplepush_user(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="simplepush", ) @@ -374,7 +394,7 @@ def test_invalid_uaid_not_found(self): def throw_item(*args, **kwargs): raise ItemNotFound("Not found") - schema.context["settings"].router.get_uaid.side_effect = throw_item + schema.context["db"].router.get_uaid.side_effect = throw_item with assert_raises(InvalidRequest) as cm: schema.load(self._make_test_data()) @@ -388,7 +408,7 @@ def 
test_critical_failure(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="fcm", critical_failure="Bad SenderID", ) @@ -405,7 +425,7 @@ def test_invalid_header_combo(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -429,7 +449,7 @@ def test_invalid_header_combo_04(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -456,7 +476,7 @@ def test_missing_encryption_salt(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -481,7 +501,7 @@ def test_missing_encryption_salt_04(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -506,7 +526,7 @@ def test_missing_encryption_key_dh(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -531,7 +551,7 @@ def test_missing_crypto_key_dh(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", uaid=dummy_uaid, router_data=dict(creds=dict(senderID="bogus")), @@ -557,7 +577,7 @@ def test_invalid_data_size(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", uaid=dummy_uaid, router_data=dict(creds=dict(senderID="bogus")), @@ -581,7 +601,7 @@ def test_invalid_data_must_have_crypto_headers(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -598,7 +618,7 @@ def test_valid_data_crypto_padding_stripped(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -625,7 +645,7 @@ def test_invalid_dh_value_for_01_crypto(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", router_data=dict(creds=dict(senderID="bogus")), ) @@ -656,7 +676,7 @@ def test_invalid_vapid_crypto_header(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", uaid=dummy_uaid, router_data=dict(creds=dict(senderID="bogus")), @@ -684,7 +704,7 @@ def test_invalid_topic(self): 
chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="gcm", uaid=dummy_uaid, router_data=dict(creds=dict(senderID="bogus")), @@ -725,7 +745,7 @@ def test_no_current_month(self): chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="webpush", uaid=dummy_uaid, ) @@ -741,13 +761,13 @@ def test_no_current_month(self): def test_old_current_month(self): schema = self._make_fut() - schema.context["settings"].message_tables = dict() + schema.context["db"].message_tables = dict() schema.context["settings"].parse_endpoint.return_value = dict( uaid=dummy_uaid, chid=dummy_chid, public_key="", ) - schema.context["settings"].router.get_uaid.return_value = dict( + schema.context["db"].router.get_uaid.return_value = dict( router_type="webpush", uaid=dummy_uaid, current_month="message_2014_01", @@ -767,14 +787,20 @@ class TestWebPushRequestSchemaUsingVapid(unittest.TestCase): def _make_fut(self): from autopush.web.webpush import WebPushRequestSchema from autopush.settings import AutopushSettings - schema = WebPushRequestSchema() - schema.context["log"] = Mock() - schema.context["settings"] = settings = AutopushSettings( + settings = AutopushSettings( hostname="localhost", statsd_host=None, ) - settings.router = Mock() - settings.router.get_uaid.return_value = dict( + db = test_db() + schema = WebPushRequestSchema() + schema.context.update( + settings=settings, + metrics=SinkMetrics(), + db=db, + routers=Mock(), + log=Mock() + ) + db.router.get_uaid.return_value = dict( router_type="gcm", uaid=dummy_uaid, router_data=dict(creds=dict(senderID="bogus")), diff --git a/autopush/tests/test_web_webpush.py b/autopush/tests/test_web_webpush.py index c20aec4c..6654f562 100644 --- a/autopush/tests/test_web_webpush.py +++ b/autopush/tests/test_web_webpush.py @@ -8,11 +8,12 @@ from twisted.internet.defer import inlineCallbacks from twisted.trial import unittest -from autopush.db import Router, create_rotating_message_table +from autopush.db import Message, create_rotating_message_table from autopush.http import EndpointHTTPFactory from autopush.router.interface import IRouter, RouterResponse from autopush.settings import AutopushSettings from autopush.tests.client import Client +from autopush.tests.support import test_db dummy_uaid = str(uuid.UUID("abad1dea00000000aabbccdd00000000")) dummy_chid = str(uuid.UUID("deadbeef00000000decafbad00000000")) @@ -38,14 +39,13 @@ def setUp(self): statsd_host=None, ) self.fernet_mock = settings.fernet = Mock(spec=Fernet) - self.router_mock = settings.router = Mock(spec=Router) - settings.routers["webpush"] = Mock(spec=IRouter) - self.wp_router_mock = settings.routers["webpush"] - self.message_mock = settings.message = Mock() + self.db = db = test_db() + self.message_mock = db.message = Mock(spec=Message) self.message_mock.all_channels.return_value = (True, [dummy_chid]) - app = EndpointHTTPFactory.for_handler(WebPushHandler, settings) + app = EndpointHTTPFactory.for_handler(WebPushHandler, settings, db=db) + self.wp_router_mock = app.routers["webpush"] = Mock(spec=IRouter) self.client = Client(app) def url(self, **kwargs): @@ -59,11 +59,11 @@ def test_router_needs_update(self): public_key="asdfasdf", )) self.fernet_mock.decrypt.return_value = dummy_token - self.router_mock.get_uaid.return_value = dict( + self.db.router.get_uaid.return_value = dict( 
router_type="webpush", router_data=dict(), uaid=dummy_uaid, - current_month=self.ap_settings.current_msg_month, + current_month=self.db.current_msg_month, ) self.wp_router_mock.route_notification.return_value = RouterResponse( status_code=503, @@ -74,7 +74,7 @@ def test_router_needs_update(self): self.url(api_ver="v1", token=dummy_token), ) eq_(resp.get_status(), 503) - ru = self.router_mock.register_user + ru = self.db.router.register_user ok_(ru.called) eq_('webpush', ru.call_args[0][0].get('router_type')) @@ -86,11 +86,11 @@ def test_router_returns_data_without_detail(self): public_key="asdfasdf", )) self.fernet_mock.decrypt.return_value = dummy_token - self.router_mock.get_uaid.return_value = dict( + self.db.router.get_uaid.return_value = dict( uaid=dummy_uaid, router_type="webpush", router_data=dict(uaid="uaid"), - current_month=self.ap_settings.current_msg_month, + current_month=self.db.current_msg_month, ) self.wp_router_mock.route_notification.return_value = RouterResponse( status_code=503, @@ -101,7 +101,7 @@ def test_router_returns_data_without_detail(self): self.url(api_ver="v1", token=dummy_token), ) eq_(resp.get_status(), 503) - ok_(self.router_mock.drop_user.called) + ok_(self.db.router.drop_user.called) @inlineCallbacks def test_request_bad_ckey(self): @@ -169,8 +169,7 @@ def test_request_bad_v2_id_missing_pubkey(self): def test_request_v2_id_variant_pubkey(self): self.fernet_mock.decrypt.return_value = 'a' * 32 variant_key = base64.urlsafe_b64encode("0V0" + ('a' * 85)) - self.ap_settings.router.get_uaid = Mock() - self.ap_settings.router.get_uaid.return_value = dict( + self.db.router.get_uaid.return_value = dict( uaid=dummy_uaid, chid=dummy_chid, router_type="gcm", @@ -186,8 +185,7 @@ def test_request_v2_id_variant_pubkey(self): @inlineCallbacks def test_request_v2_id_no_crypt_auth(self): self.fernet_mock.decrypt.return_value = 'a' * 32 - self.ap_settings.router.get_uaid = Mock() - self.ap_settings.router.get_uaid.return_value = dict( + self.db.router.get_uaid.return_value = dict( uaid=dummy_uaid, chid=dummy_chid, router_type="gcm", diff --git a/autopush/tests/test_websocket.py b/autopush/tests/test_websocket.py index d7d0cc81..c47f3ae4 100644 --- a/autopush/tests/test_websocket.py +++ b/autopush/tests/test_websocket.py @@ -14,7 +14,6 @@ from boto.exception import JSONResponseError from mock import Mock, patch from nose.tools import assert_raises, eq_, ok_ -from txstatsd.metrics.metrics import Metrics from twisted.internet import reactor from twisted.internet.defer import ( inlineCallbacks, @@ -25,8 +24,9 @@ from twisted.trial import unittest import autopush.db as db -from autopush.db import create_rotating_message_table +from autopush.db import DatabaseManager, create_rotating_message_table from autopush.http import InternalRouterHTTPFactory +from autopush.metrics import SinkMetrics from autopush.settings import AutopushSettings from autopush.tests import MockAssist from autopush.utils import WebPushNotification @@ -38,6 +38,7 @@ RouterHandler, NotificationHandler, WebSocketServerProtocol, + periodic_reporter ) from autopush.utils import base64url_encode, ms_time @@ -114,7 +115,11 @@ def setUp(self): statsd_host=None, env="test", ) - self.factory = PushServerFactory(settings) + db = DatabaseManager.from_settings(settings) + self.metrics = db.metrics = Mock(spec=SinkMetrics) + db.create_initial_message_tables() + + self.factory = PushServerFactory(settings, db) self.proto = self.factory.buildProtocol(('localhost', 8080)) self.proto._log_exc = False self.proto.log = 
Mock(spec=Logger) @@ -123,13 +128,12 @@ def setUp(self): self.orig_close = self.proto.sendClose request_mock = Mock() request_mock.headers = {} - self.proto.ps = PushState(settings=settings, request=request_mock) + self.proto.ps = PushState(db=db, request=request_mock) self.proto.sendClose = self.close_mock = Mock() self.proto.transport = self.transport_mock = Mock() self.proto.closeHandshakeTimeout = 0 self.proto.autoPingInterval = 300 self.proto._force_retry = self.proto.force_retry - settings.metrics = Mock(spec=Metrics) def tearDown(self): self.proto.force_retry = self.proto._force_retry @@ -215,7 +219,7 @@ def test_nuke_connection(self, mock_reactor): self.proto.state = "" self.proto.ps.uaid = uuid.uuid4().hex self.proto.nukeConnection() - ok_(self.proto.ap_settings.metrics.increment.called) + ok_(self.proto.metrics.increment.called) @patch("autopush.websocket.reactor") def test_nuke_connection_shutdown_ran(self, mock_reactor): @@ -252,18 +256,18 @@ def test_base_tags(self): "rv:1.9.2.3) Gecko/20100401 Firefox/3.6.3 (.NET " "CLR 3.5.30729)"} req.host = "example.com:8080" - ps = PushState(settings=self.proto.ap_settings, request=req) + ps = PushState(db=self.proto.db, request=req) eq_(sorted(ps._base_tags), sorted(['ua_os_family:Windows', 'ua_browser_family:Firefox', 'host:example.com:8080'])) def test_reporter(self): - from autopush.websocket import periodic_reporter - periodic_reporter(self.ap_settings, self.factory) + self.metrics.reset_mock() + periodic_reporter(self.ap_settings, self.metrics, self.factory) # Verify metric increase of nothing - calls = self.ap_settings.metrics.method_calls + calls = self.metrics.method_calls eq_(len(calls), 4) name, args, _ = calls[0] eq_(name, "gauge") @@ -382,8 +386,8 @@ def test_close_with_delivery_cleanup(self): self.proto.ps.direct_updates[chid] = 12 # Apply some mocks - self.proto.ap_settings.storage.save_notification = Mock() - self.proto.ap_settings.router.get_uaid = mock_get = Mock() + self.proto.db.storage.save_notification = Mock() + self.proto.db.router.get_uaid = mock_get = Mock() self.proto.ap_settings.agent = mock_agent = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -403,8 +407,8 @@ def test_close_with_delivery_cleanup_using_webpush(self): self.proto.ps.direct_updates[dummy_chid_str] = [dummy_notif()] # Apply some mocks - self.proto.ap_settings.message.store_message = Mock() - self.proto.ap_settings.router.get_uaid = mock_get = Mock() + self.proto.db.message.store_message = Mock() + self.proto.db.router.get_uaid = mock_get = Mock() self.proto.ap_settings.agent = mock_agent = Mock() mock_get.return_value = dict(node_id="localhost:2000") @@ -424,16 +428,16 @@ def test_close_with_delivery_cleanup_and_get_no_result(self): self.proto.ps.direct_updates[chid] = 12 # Apply some mocks - self.proto.ap_settings.storage.save_notification = Mock() - self.proto.ap_settings.router.get_uaid = mock_get = Mock() - self.proto.ps.metrics = mock_metrics = Mock() + self.proto.db.storage.save_notification = Mock() + self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = False + self.metrics.reset_mock() # Close the connection self.proto.onClose(True, None, None) - yield self._wait_for(lambda: len(mock_metrics.mock_calls) > 2) - eq_(len(mock_metrics.mock_calls), 3) - mock_metrics.increment.assert_called_with( + yield self._wait_for(lambda: len(self.metrics.mock_calls) > 2) + eq_(len(self.metrics.mock_calls), 3) + self.metrics.increment.assert_called_with( "client.notify_uaid_failure", tags=None) @inlineCallbacks @@ 
-447,9 +451,9 @@ def test_close_with_delivery_cleanup_and_get_uaid_error(self): self.proto.ps.direct_updates[chid] = 12 # Apply some mocks - self.proto.ap_settings.storage.save_notification = Mock() - self.proto.ap_settings.router.get_uaid = mock_get = Mock() - self.proto.ps.metrics = mock_metrics = Mock() + self.proto.db.storage.save_notification = Mock() + self.proto.db.router.get_uaid = mock_get = Mock() + self.metrics.reset_mock() def raise_item(*args, **kwargs): raise ItemNotFound() @@ -458,9 +462,9 @@ def raise_item(*args, **kwargs): # Close the connection self.proto.onClose(True, None, None) - yield self._wait_for(lambda: len(mock_metrics.mock_calls) > 2) - eq_(len(mock_metrics.mock_calls), 3) - mock_metrics.increment.assert_called_with( + yield self._wait_for(lambda: len(self.metrics.mock_calls) > 2) + eq_(len(self.metrics.mock_calls), 3) + self.metrics.increment.assert_called_with( "client.lookup_uaid_failure", tags=None) @inlineCallbacks @@ -474,8 +478,8 @@ def test_close_with_delivery_cleanup_and_no_node_id(self): self.proto.ps.direct_updates[chid] = 12 # Apply some mocks - self.proto.ap_settings.storage.save_notification = Mock() - self.proto.ap_settings.router.get_uaid = mock_get = Mock() + self.proto.db.storage.save_notification = Mock() + self.proto.db.router.get_uaid = mock_get = Mock() mock_get.return_value = mock_node_get = Mock() mock_node_get.get.return_value = None @@ -491,7 +495,7 @@ def test_hello_old(self): target_day = datetime.date(2016, 2, 29) msg_day = datetime.date(2015, 12, 15) msg_date = "{}_{}_{}".format( - self.proto.ap_settings._message_prefix, + self.proto.db._message_prefix, msg_day.year, msg_day.month) msg_data = { @@ -500,7 +504,7 @@ def test_hello_old(self): "last_connect": int(msg_day.strftime("%s")), "current_month": msg_date, } - router = self.proto.ap_settings.router + router = self.proto.db.router router.table.put_item(data=dict( uaid=orig_uaid, connected_at=ms_time(), @@ -513,7 +517,7 @@ def fake_msg(data): mock_msg = Mock(wraps=db.Message) mock_msg.fetch_messages.return_value = [] - self.proto.ap_settings.router.register_user = fake_msg + self.proto.db.router.register_user = fake_msg # because we're faking the dates, process_notifications will key # error and fail to return. This will cause the expected path for # this test to fail. Since we're requesting the client to change @@ -521,7 +525,7 @@ def fake_msg(data): # notifications are irrelevant for this test. 
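The hello/rotation tests that follow repeatedly replace the protocol's rotating tables with a single mocked month; the same few lines recur in each one. A small helper capturing that pattern might look like the sketch below (the helper name and default table name are illustrative, not part of the diff):

```python
# Illustrative helper only; `fake_message_month` and the default table
# name are not part of the diff.
from mock import Mock

import autopush.db as db


def fake_message_month(proto, table_name="message_2016_1"):
    """Swap the protocol's rotating tables for a single mocked month."""
    mock_msg = Mock(wraps=db.Message)
    mock_msg.fetch_messages.return_value = "01;", []
    mock_msg.fetch_timestamp_messages.return_value = None, []
    mock_msg.all_channels.return_value = (None, [])

    tables = proto.ps.db.message_tables
    tables.clear()
    tables[table_name] = mock_msg
    return mock_msg
```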
self.proto.process_notifications = Mock() # massage message_tables to include our fake range - mt = self.proto.ps.settings.message_tables + mt = self.proto.ps.db.message_tables for k in mt.keys(): del(mt[k]) mt['message_2016_1'] = mock_msg @@ -547,7 +551,7 @@ def fake_msg(data): @inlineCallbacks def test_hello_tomorrow(self): orig_uaid = "deadbeef00000000abad1dea00000000" - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=orig_uaid, connected_at=ms_time(), @@ -559,7 +563,7 @@ def test_hello_tomorrow(self): target_day = datetime.date(2016, 2, 29) msg_day = datetime.date(2016, 3, 1) msg_date = "{}_{}_{}".format( - self.proto.ap_settings._message_prefix, + self.proto.db._message_prefix, msg_day.year, msg_day.month) msg_data = { @@ -576,9 +580,9 @@ def fake_msg(data): mock_msg.fetch_messages.return_value = "01;", [] mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) - self.proto.ap_settings.router.register_user = fake_msg + self.proto.db.router.register_user = fake_msg # massage message_tables to include our fake range - mt = self.proto.ps.settings.message_tables + mt = self.proto.ps.db.message_tables for k in mt.keys(): del(mt[k]) mt['message_2016_1'] = mock_msg @@ -606,7 +610,7 @@ def fake_msg(data): @inlineCallbacks def test_hello_tomorrow_provision_error(self): orig_uaid = "deadbeef00000000abad1dea00000000" - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=orig_uaid, connected_at=ms_time(), @@ -618,7 +622,7 @@ def test_hello_tomorrow_provision_error(self): target_day = datetime.date(2016, 2, 29) msg_day = datetime.date(2016, 3, 1) msg_date = "{}_{}_{}".format( - self.proto.ap_settings._message_prefix, + self.proto.db._message_prefix, msg_day.year, msg_day.month) msg_data = { @@ -635,9 +639,9 @@ def fake_msg(data): mock_msg.fetch_messages.return_value = "01;", [] mock_msg.fetch_timestamp_messages.return_value = None, [] mock_msg.all_channels.return_value = (None, []) - self.proto.ap_settings.router.register_user = fake_msg + self.proto.db.router.register_user = fake_msg # massage message_tables to include our fake range - mt = self.proto.ps.settings.message_tables + mt = self.proto.ps.db.message_tables mt.clear() mt['message_2016_1'] = mock_msg mt['message_2016_2'] = mock_msg @@ -650,7 +654,7 @@ def fake_msg(data): def raise_error(*args): raise ProvisionedThroughputExceededException(None, None) - self.proto.ap_settings.router.update_message_month = MockAssist([ + self.proto.db.router.update_message_month = MockAssist([ raise_error, Mock(), ]) @@ -716,7 +720,7 @@ def test_hello_with_webpush(self): def test_hello_with_missing_router_type(self): self._connect() uaid = uuid.uuid4().hex - router = self.proto.ap_settings.router + router = self.proto.db.router router.table.put_item(data=dict( uaid=uaid, connected_at=ms_time()-1000, @@ -732,7 +736,7 @@ def test_hello_with_missing_router_type(self): def test_hello_with_missing_current_month(self): self._connect() uaid = uuid.uuid4().hex - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=uaid, connected_at=ms_time(), @@ -748,7 +752,7 @@ def test_hello_with_missing_current_month(self): def test_hello_with_uaid(self): self._connect() uaid = uuid.uuid4().hex - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=uaid, connected_at=ms_time(), @@ -764,7 +768,7 @@ def 
test_hello_with_uaid(self): def test_hello_resets_record(self): self._connect() uaid = uuid.uuid4().hex - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=uaid, connected_at=ms_time(), @@ -811,7 +815,7 @@ def test_hello_with_bad_uaid_case(self): def test_hello_failure(self): self._connect() # Fail out the register_user call - router = self.proto.ap_settings.router + router = self.proto.db.router router.table.connection.update_item = Mock(side_effect=KeyError) self._send_message(dict(messageType="hello", channelIDs=[], stop=1)) @@ -829,7 +833,7 @@ def test_hello_provisioned_during_check(self): def throw_error(*args, **kwargs): raise ProvisionedThroughputExceededException(None, None) - router = self.proto.ap_settings.router + router = self.proto.db.router router.table.connection.update_item = Mock(side_effect=throw_error) self._send_message(dict(messageType="hello", channelIDs=[])) @@ -848,7 +852,7 @@ def test_hello_jsonresponseerror(self): def throw_error(*args, **kwargs): raise JSONResponseError(None, None) - router = self.proto.ap_settings.router + router = self.proto.db.router router.table.connection.update_item = Mock(side_effect=throw_error) self._send_message(dict(messageType="hello", channelIDs=[])) @@ -862,12 +866,12 @@ def test_hello_check_fail(self): self._connect() # Fail out the register_user call - self.proto.ap_settings.router.register_user = \ + self.proto.db.router.register_user = \ Mock(return_value=(False, {})) self._send_message(dict(messageType="hello", channelIDs=[])) msg = yield self.get_response() - calls = self.proto.ap_settings.router.register_user.mock_calls + calls = self.proto.db.router.register_user.mock_calls eq_(len(calls), 1) eq_(msg["status"], 500) eq_(msg["reason"], "already_connected") @@ -924,7 +928,7 @@ def test_hello_udp(self): "ignored": "ok"})) msg = yield self.get_response() eq_(msg["status"], 200) - route_data = self.proto.ap_settings.router.get_uaid( + route_data = self.proto.db.router.get_uaid( msg["uaid"]).get('wake_data') eq_(route_data, {'data': {"ip": "127.0.0.1", "port": 9999, "mcc": "hammer", @@ -942,7 +946,7 @@ def test_bad_hello_udp(self): msg = yield self.get_response() eq_(msg["status"], 200) ok_("wake_data" not in - self.proto.ap_settings.router.get_uaid(msg["uaid"]).keys()) + self.proto.db.router.get_uaid(msg["uaid"]).keys()) @inlineCallbacks def test_not_hello(self): @@ -1069,10 +1073,10 @@ def test_register_webpush(self): self.proto.ps.use_webpush = True chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.ap_settings.message.register_channel = Mock() + self.proto.db.message.register_channel = Mock() yield self.proto.process_register(dict(channelID=chid)) - ok_(self.proto.ap_settings.message.register_channel.called) + ok_(self.proto.db.message.register_channel.called) assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -1081,7 +1085,7 @@ def test_register_webpush_with_key(self): self.proto.ps.use_webpush = True chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.ap_settings.message.register_channel = Mock() + self.proto.db.message.register_channel = Mock() test_key = "SomeRandomCryptoKeyString" test_sha = sha256(test_key).hexdigest() test_endpoint = ('http://localhost/wpush/v2/' + @@ -1101,7 +1105,7 @@ def echo(string): ) eq_(test_endpoint, self.proto.sendJSON.call_args[0][0]['pushEndpoint']) - ok_(self.proto.ap_settings.message.register_channel.called) + 
ok_(self.proto.db.message.register_channel.called) assert_called_included(self.proto.log.info, format="Register") @inlineCallbacks @@ -1205,7 +1209,7 @@ def test_register_over_provisioning(self): self.proto.ps.use_webpush = True chid = str(uuid.uuid4()) self.proto.ps.uaid = uuid.uuid4().hex - self.proto.ap_settings.message.register_channel = register = Mock() + self.proto.db.message.register_channel = register = Mock() def throw_provisioned(*args, **kwargs): raise ProvisionedThroughputExceededException(None, None) @@ -1213,7 +1217,7 @@ def throw_provisioned(*args, **kwargs): register.side_effect = throw_provisioned yield self.proto.process_register(dict(channelID=chid)) - ok_(self.proto.ap_settings.message.register_channel.called) + ok_(self.proto.db.message.register_channel.called) ok_(self.send_mock.called) args, _ = self.send_mock.call_args msg = json.loads(args[0]) @@ -1313,7 +1317,7 @@ def test_ws_unregister_fail(self): chid = str(uuid.uuid4()) # Replace storage delete with call to fail - table = self.proto.ap_settings.storage.table + table = self.proto.db.storage.table delete = table.delete_item def raise_exception(*args, **kwargs): @@ -1524,7 +1528,7 @@ def __init__(self): def __call__(self, *args, **kwargs): return self.tries != 0 - self.proto.ap_settings.storage = Mock( + self.proto.db.storage = Mock( **{"delete_notification.side_effect": FailFirst()}) chid = str(uuid.uuid4()) @@ -1584,7 +1588,7 @@ def test_process_notifications(self): self.proto.ps.uaid = uuid.uuid4().hex # Swap out fetch_notifications - self.proto.ap_settings.storage.fetch_notifications = Mock( + self.proto.db.storage.fetch_notifications = Mock( return_value=[] ) @@ -1619,7 +1623,7 @@ def throw_error(*args): raise ProvisionedThroughputExceededException(None, None) # Swap out fetch_notifications - self.proto.ap_settings.storage.fetch_notifications = MockAssist([ + self.proto.db.storage.fetch_notifications = MockAssist([ throw_error, [], ]) @@ -1645,7 +1649,7 @@ def test_process_notification_error(self): def throw_error(*args, **kwargs): raise Exception("An error happened!") - self.proto.ap_settings.storage = Mock( + self.proto.db.storage = Mock( **{"fetch_notifications.side_effect": throw_error}) self.proto.ps._check_notifications = True self.proto.process_notifications() @@ -1747,11 +1751,11 @@ def test_notif_finished_with_webpush_with_old_notifications(self): def test_notif_finished_with_too_many_messages(self): self._connect() + self.ap_settings.msg_limit = 2 self.proto.ps.uaid = uuid.uuid4().hex self.proto.ps.use_webpush = True self.proto.ps._check_notifications = True - self.proto.ps.msg_limit = 2 - self.proto.ap_settings.router.drop_user = Mock() + self.proto.db.router.drop_user = Mock() self.proto.ps.message.fetch_messages = Mock() notif = make_webpush_notification( @@ -1786,14 +1790,14 @@ def test_notification_results(self): chid3 = str(uuid.uuid4()) # Create a router record - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=uaid, connected_at=ms_time(), router_type="simplepush", )) - storage = self.proto.ap_settings.storage + storage = self.proto.db.storage storage.save_notification(uaid, chid, 12) storage.save_notification(uaid, chid2, 8) storage.save_notification(uaid, chid3, 9) @@ -1831,14 +1835,14 @@ def test_notification_dont_deliver_after_ack(self): chid = str(uuid.uuid4()) # Create a dummy router record - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=uaid, connected_at=ms_time(), 
router_type="simplepush", )) - storage = self.proto.ap_settings.storage + storage = self.proto.db.storage storage.save_notification(uaid, chid, 10) # Verify the message is stored @@ -1878,14 +1882,14 @@ def test_notification_dont_deliver(self): chid = str(uuid.uuid4()) # Create a dummy router record - router = self.proto.ap_settings.router + router = self.proto.db.router router.register_user(dict( uaid=uaid, connected_at=ms_time(), router_type="simplepush", )) - storage = self.proto.ap_settings.storage + storage = self.proto.db.storage storage.save_notification(uaid, chid, 12) # Verify the message is stored @@ -1920,7 +1924,7 @@ def test_notification_dont_deliver(self): eq_(len(calls), 1) def test_incomplete_uaid(self): - mm = self.proto.ap_settings.router = Mock() + mm = self.proto.db.router = Mock() fr = self.proto.force_retry = Mock() uaid = uuid.uuid4().hex mm.get_uaid.return_value = { diff --git a/autopush/web/base.py b/autopush/web/base.py index 9fa44798..d1f110b5 100644 --- a/autopush/web/base.py +++ b/autopush/web/base.py @@ -55,8 +55,13 @@ def _validate_request(self, request_handler, *args, **kwargs): "arguments": request_handler.request.arguments, } schema = self.schema() - schema.context["settings"] = request_handler.ap_settings - schema.context["log"] = self.log + schema.context.update( + settings=request_handler.ap_settings, + metrics=request_handler.metrics, + db=request_handler.db, + routers=request_handler.routers, + log=self.log + ) return schema.load(data) def _call_func(self, result, func, request_handler): @@ -139,14 +144,17 @@ class BaseWebHandler(BaseHandler): ############################################################# # Cyclone API Methods ############################################################# - def initialize(self, ap_settings): + def initialize(self): """Setup basic aliases and attributes""" - super(BaseWebHandler, self).initialize(ap_settings) - self.metrics = ap_settings.metrics + super(BaseWebHandler, self).initialize() self._base_tags = {} self._start_time = time.time() self._timings = {} + @property + def routers(self): + return self.application.routers + def prepare(self): """Common request preparation""" if self.ap_settings.enable_tls_auth: diff --git a/autopush/web/health.py b/autopush/web/health.py index 863a532e..3f140c91 100644 --- a/autopush/web/health.py +++ b/autopush/web/health.py @@ -32,8 +32,8 @@ def get(self): } dl = DeferredList([ - self._check_table(self.ap_settings.router.table), - self._check_table(self.ap_settings.storage.table) + self._check_table(self.db.router.table), + self._check_table(self.db.storage.table) ]) dl.addBoth(self._finish_response) diff --git a/autopush/web/message.py b/autopush/web/message.py index 2212e469..b3162d9e 100644 --- a/autopush/web/message.py +++ b/autopush/web/message.py @@ -40,8 +40,7 @@ def delete(self, notification): """ - d = deferToThread(self.ap_settings.message.delete_message, - notification) + d = deferToThread(self.db.message.delete_message, notification) d.addCallback(self._delete_completed) self._db_error_handling(d) return d diff --git a/autopush/web/registration.py b/autopush/web/registration.py index 5715e449..edd4cfdb 100644 --- a/autopush/web/registration.py +++ b/autopush/web/registration.py @@ -102,7 +102,7 @@ class TypeAppSchema(Schema): @validates("router_type") def validate_router_type(self, value): - if value not in self.context['settings'].routers: + if value not in self.context['routers']: raise InvalidRequest("Invalid router", status_code=400, errno=108) @@ -115,7 +115,7 @@ 
class TypeAppUaidSchema(TypeAppSchema): @validates("uaid") def validate_uaid(self, value): try: - self.context['settings'].router.get_uaid(value.hex) + self.context['db'].router.get_uaid(value.hex) except ItemNotFound: raise InvalidRequest("UAID not found", status_code=410, errno=103) @@ -191,7 +191,7 @@ class RouterDataSchema(Schema): @validates_schema(skip_on_field_errors=True) def register_router(self, data): router_type = data["path_kwargs"]["router_type"] - router = self.context["settings"].routers[router_type] + router = self.context["routers"][router_type] try: router.register(uaid="", router_data=data["router_data"], app_id=data["path_kwargs"]["app_id"]) @@ -284,13 +284,13 @@ def base_tags(self): def _register_channel(self, uaid, chid, app_server_key): # type: (uuid.UUID, str, Optional[str]) -> str """Register a new channel and create/return its endpoint""" - self.ap_settings.message.register_channel(uaid.hex, chid) + self.db.message.register_channel(uaid.hex, chid) return self.ap_settings.make_endpoint(uaid.hex, chid, app_server_key) def _register_user(self, uaid, router_type, router_data): # type: (uuid.UUID, str, JSONDict) -> None """Save a new user record""" - self.ap_settings.router.register_user(dict( + self.db.router.register_user(dict( uaid=uaid.hex, router_type=router_type, router_data=router_data, @@ -310,7 +310,7 @@ def _write_endpoint(self, endpoint, uaid, chid, router_type, router_data, self.ap_settings.bear_hash_key[0], uaid.hex) response.update(uaid=uaid.hex, secret=secret) # Apply any router specific fixes to the outbound response. - router = self.ap_settings.routers[router_type] + router = self.routers[router_type] router.amend_endpoint_response(response, router_data) self.set_header("Content-Type", "application/json") self.write(json.dumps(response)) @@ -337,8 +337,8 @@ def post(self, router_type, router_data): Router type/data registration. 
""" - self.ap_settings.metrics.increment("updates.client.register", - tags=self.base_tags()) + self.metrics.increment("updates.client.register", + tags=self.base_tags()) uaid = uuid.uuid4() @@ -372,7 +372,7 @@ def get(self, uaid): Return a list of known channelIDs for a given UAID """ - d = deferToThread(self.ap_settings.message.all_channels, str(uaid)) + d = deferToThread(self.db.message.all_channels, str(uaid)) d.addCallback(self._write_channels, uaid) d.addErrback(self._uaid_not_found_err) d.addErrback(self._response_err) @@ -411,7 +411,7 @@ def delete(self, uaid): def _delete_uaid(self, uaid): self.log.info(format="Dropping User", code=101, uaid_hash=hasher(uaid.hex)) - if not self.ap_settings.router.drop_user(uaid.hex): + if not self.db.router.drop_user(uaid.hex): raise ItemNotFound("UAID not found") def _uaid_not_found_err(self, fail): @@ -455,8 +455,8 @@ class ChannelRegistrationHandler(BaseRegistrationHandler): @threaded_validate(UnregisterChidSchema) def delete(self, uaid, chid): # type: (uuid.UUID, str) -> Deferred - self.ap_settings.metrics.increment("updates.client.unregister", - tags=self.base_tags()) + self.metrics.increment("updates.client.unregister", + tags=self.base_tags()) d = deferToThread(self._delete_channel, uaid, chid) d.addCallback(self._success) d.addErrback(self._chid_not_found_err) @@ -464,7 +464,7 @@ def delete(self, uaid, chid): return d def _delete_channel(self, uaid, chid): - if not self.ap_settings.message.unregister_channel(uaid.hex, chid): + if not self.db.message.unregister_channel(uaid.hex, chid): raise ItemNotFound("ChannelID not found") def _chid_not_found_err(self, fail): diff --git a/autopush/web/simplepush.py b/autopush/web/simplepush.py index e7b03e5d..9c47d395 100644 --- a/autopush/web/simplepush.py +++ b/autopush/web/simplepush.py @@ -36,6 +36,7 @@ class SimplePushSubscriptionSchema(Schema): def extract_subscription(self, d): try: result = self.context["settings"].parse_endpoint( + self.context["metrics"], token=d["token"], version=d["api_ver"], ) @@ -46,7 +47,7 @@ def extract_subscription(self, d): @validates_schema def validate_uaid_chid(self, d): try: - result = self.context["settings"].router.get_uaid(d["uaid"].hex) + result = self.context["db"].router.get_uaid(d["uaid"].hex) except ItemNotFound: raise InvalidRequest("UAID not found", status_code=410, errno=103) @@ -118,7 +119,7 @@ def put(self, subscription, version, data): channel_id=str(subscription["chid"]), ) - router = self.ap_settings.routers[user_data["router_type"]] + router = self.routers[user_data["router_type"]] d = maybeDeferred(router.route_notification, notification, user_data) d.addCallback(self._router_completed, user_data, "") d.addErrback(self._router_fail_err) diff --git a/autopush/web/webpush.py b/autopush/web/webpush.py index 4b88d8aa..b694bb95 100644 --- a/autopush/web/webpush.py +++ b/autopush/web/webpush.py @@ -25,13 +25,13 @@ ) from autopush.crypto_key import CryptoKey +from autopush.db import DatabaseManager # noqa from autopush.db import dump_uaid, hasher from autopush.exceptions import ( InvalidRequest, InvalidTokenException, VapidAuthException, ) -from autopush.settings import AutopushSettings # noqa from autopush.utils import ( base64url_encode, extract_jwt, @@ -61,6 +61,7 @@ class WebPushSubscriptionSchema(Schema): def extract_subscription(self, d): try: result = self.context["settings"].parse_endpoint( + self.context["metrics"], token=d["token"], version=d["api_ver"], ckey_header=d["ckey_header"], @@ -75,9 +76,10 @@ def extract_subscription(self, d): 
    @validates_schema(skip_on_field_errors=True)
    def validate_uaid_month_and_chid(self, d):
-        settings = self.context["settings"]  # type: AutopushSettings
+        db = self.context["db"]  # type: DatabaseManager
+
        try:
-            result = settings.router.get_uaid(d["uaid"].hex)
+            result = db.router.get_uaid(d["uaid"].hex)
        except ItemNotFound:
            raise InvalidRequest("UAID not found",
                                 status_code=410, errno=103)
@@ -90,7 +92,7 @@ def validate_uaid_month_and_chid(self, d):
            # Make sure we note that this record is bad.
            result['critical_failure'] = \
                result.get('critical_failure', "Missing SenderID")
-            settings.router.register_user(result)
+            db.router.register_user(result)

        if result.get("critical_failure"):
            raise InvalidRequest("Critical Failure: %s" %
@@ -105,7 +107,7 @@ def validate_uaid_month_and_chid(self, d):
        d["user_data"] = result

    def _validate_webpush(self, d, result):
-        settings = self.context["settings"]  # type: AutopushSettings
+        db = self.context["db"]  # type: DatabaseManager
        log = self.context["log"]  # type: Logger
        channel_id = normalize_id(d["chid"])
        uaid = result["uaid"]
@@ -113,20 +115,19 @@ def _validate_webpush(self, d, result):
            log.info(format="Dropping User", code=102,
                     uaid_hash=hasher(uaid),
                     uaid_record=dump_uaid(result))
-            settings.router.drop_user(uaid)
+            db.router.drop_user(uaid)
            raise InvalidRequest("No such subscription", status_code=410,
                                 errno=106)

        month_table = result["current_month"]
-        if month_table not in settings.message_tables:
+        if month_table not in db.message_tables:
            log.info(format="Dropping User", code=103,
                     uaid_hash=hasher(uaid),
                     uaid_record=dump_uaid(result))
-            settings.router.drop_user(uaid)
+            db.router.drop_user(uaid)
            raise InvalidRequest("No such subscription", status_code=410,
                                 errno=106)
-        exists, chans = settings.message_tables[month_table].all_channels(
-            uaid=uaid)
+        exists, chans = db.message_tables[month_table].all_channels(uaid=uaid)

        if (not exists or channel_id.lower() not in
                map(lambda x: normalize_id(x), chans)):
@@ -461,7 +462,7 @@ def post(self,
            version=notification.version,
            encoding=encoding,
        )
-        router = self.ap_settings.routers[user_data["router_type"]]
+        router = self.routers[user_data["router_type"]]
        self._router_time = time.time()
        d = maybeDeferred(router.route_notification, notification, user_data)
        d.addCallback(self._router_completed, user_data, "")
@@ -484,8 +485,7 @@ def _router_completed(self, response, uaid_data, warning=""):
                          uaid_hash=hasher(uaid_data["uaid"]),
                          uaid_record=dump_uaid(uaid_data),
                          client_info=self._client_info)
-            d = deferToThread(self.ap_settings.router.drop_user,
-                              uaid_data["uaid"])
+            d = deferToThread(self.db.router.drop_user, uaid_data["uaid"])
            d.addCallback(lambda x: self._router_response(response))
            return d
        # The router data needs to be updated to include any changes
@@ -493,8 +493,7 @@ def _router_completed(self, response, uaid_data, warning=""):
            uaid_data["router_data"] = response.router_data
            # set the AWS mandatory data
            uaid_data["connected_at"] = ms_time()
-            d = deferToThread(self.ap_settings.router.register_user,
-                              uaid_data)
+            d = deferToThread(self.db.router.register_user, uaid_data)
            response.router_data = None
            d.addCallback(lambda x: self._router_completed(
                response,
diff --git a/autopush/websocket.py b/autopush/websocket.py
index 429d34e6..86590ea6 100644
--- a/autopush/websocket.py
+++ b/autopush/websocket.py
@@ -87,10 +87,11 @@
    generate_last_connect,
    dump_uaid,
)
-from autopush.db import Message  # noqa
+from autopush.db import DatabaseManager, Message  # noqa
from autopush.exceptions import MessageOverloadException
from autopush.noseplugin import track_object
from autopush.protocol import IgnoreBody
+from autopush.metrics import IMetrics  # noqa
from autopush.settings import AutopushSettings  # noqa
from autopush.ssl import AutopushSSLContextFactory
from autopush.utils import (
@@ -116,18 +117,14 @@ def extract_code(data):
    return code


-def periodic_reporter(settings, factory):
-    # type: (AutopushSettings, PushServerFactory) -> None
+def periodic_reporter(settings, metrics, factory):
+    # type: (AutopushSettings, IMetrics, PushServerFactory) -> None
    """Twisted Task function that runs every few seconds to emit general
    metrics regarding twisted and client counts"""
-    settings.metrics.gauge("update.client.writers",
-                           len(reactor.getWriters()))
-    settings.metrics.gauge("update.client.readers",
-                           len(reactor.getReaders()))
-    settings.metrics.gauge("update.client.connections",
-                           len(settings.clients))
-    settings.metrics.gauge("update.client.ws_connections",
-                           factory.countConnections)
+    metrics.gauge("update.client.writers", len(reactor.getWriters()))
+    metrics.gauge("update.client.readers", len(reactor.getReaders()))
+    metrics.gauge("update.client.connections", len(settings.clients))
+    metrics.gauge("update.client.ws_connections", factory.countConnections)


def log_exception(func):
@@ -187,7 +184,6 @@ class PushState(object):
        '_base_tags',
        '_should_stop',
        '_paused',
-        'metrics',
        '_uaid_obj',
        '_uaid_hash',
        'raw_agent',
@@ -197,7 +193,7 @@ class PushState(object):
        'router_type',
        'wake_data',
        'connected_at',
-        'settings',
+        'db',
        'stats',

        # Table rotation
@@ -217,14 +213,13 @@ class PushState(object):
        'updates_sent',
        'direct_updates',

-        'msg_limit',
        '_reset_uaid',
    ]

-    def __init__(self, settings, request):
+    def __init__(self, db, request):
        self._callbacks = []
        self.stats = SessionStatistics()
-        self.settings = settings
+        self.db = db

        host = ""
        if request:
@@ -245,11 +240,11 @@ def __init__(self, settings, request):
        if host:
            self._base_tags.append("host:%s" % host)

+        db.metrics.increment("client.socket.connect",
+                             tags=self._base_tags or None)
+
        self._should_stop = False
        self._paused = False
-        self.metrics = settings.metrics
-        self.metrics.increment("client.socket.connect",
-                               tags=self._base_tags or None)
        self.uaid = None
        self.last_ping = 0
        self.check_storage = False
@@ -260,12 +255,11 @@ def __init__(self, settings, request):
        self.ping_time_out = False

        # Message table rotation initial settings
-        self.message_month = settings.current_msg_month
+        self.message_month = db.current_msg_month
        self.rotate_message_table = False

        self._check_notifications = False
        self._more_notifications = False
-        self.msg_limit = settings.message_limit

        # Timestamp message defaults
        self.scan_timestamps = False
@@ -291,7 +285,7 @@ def __init__(self, settings, request):
    def message(self):
        # type: () -> Message
        """Property to access the currently used message table"""
-        return self.settings.message_tables[self.message_month]
+        return self.db.message_tables[self.message_month]

    @property
    def user_agent(self):
@@ -370,8 +364,19 @@ class PushServerProtocol(WebSocketServerProtocol, policies.TimeoutMixin):

    @property
    def ap_settings(self):
+        # type: () -> AutopushSettings
        return self.factory.ap_settings

+    @property
+    def db(self):
+        # type: () -> DatabaseManager
+        return self.factory.db
+
+    @property
+    def metrics(self):
+        # type: () -> IMetrics
+        return self.db.metrics
+
    # Defer helpers
    def deferToThread(self, func, *args, **kwargs):
        # type (Callable[..., Any], *Any, **Any) -> Deferred
@@ -458,13 +463,13 @@ def _sendAutoPing(self):
        """Override for sanity checking during auto-ping interval"""
        if not self.ps.uaid:
            # No uaid yet, drop the connection
-            self.ps.metrics.increment("client.autoping.no_uaid",
-                                      tags=self.base_tags)
+            self.metrics.increment("client.autoping.no_uaid",
+                                   tags=self.base_tags)
            self.sendClose()
        elif self.ap_settings.clients.get(self.ps.uaid) != self:
            # UAID, but we're not in clients anymore for some reason
-            self.ps.metrics.increment("client.autoping.invalid_client",
-                                      tags=self.base_tags)
+            self.metrics.increment("client.autoping.invalid_client",
+                                   tags=self.base_tags)
            self.sendClose()
        return WebSocketServerProtocol._sendAutoPing(self)
@@ -481,13 +486,13 @@ def nukeConnection(self):
        still hadn't run by this point"""
        # Did onClose get called? If so, we shutdown properly, no worries.
        if hasattr(self, "_shutdown_ran"):
-            self.ps.metrics.increment("client.success.sendClose",
-                                      tags=self.base_tags)
+            self.metrics.increment("client.success.sendClose",
+                                   tags=self.base_tags)
            return

        # Uh-oh, we have not been shut-down properly, report detailed data
-        self.ps.metrics.increment("client.error.sendClose_failed",
-                                  tags=self.base_tags)
+        self.metrics.increment("client.error.sendClose_failed",
+                               tags=self.base_tags)

        self.transport.abortConnection()

@@ -495,7 +500,7 @@ def nukeConnection(self):
    def onConnect(self, request):
        """autobahn onConnect handler for when a connection has started"""
        track_object(self, msg="onConnect Start")
-        self.ps = PushState(settings=self.ap_settings, request=request)
+        self.ps = PushState(db=self.db, request=request)

        # Setup ourself to handle producing the data
        self.transport.bufferSize = 2 * 1024
@@ -610,11 +615,10 @@ def onClose(self, wasClean, code, reason):
    def cleanUp(self, wasClean, code, reason):
        """Thorough clean-up method to cancel all remaining deferreds, and
        send connection metrics in"""
-        self.ps.metrics.increment("client.socket.disconnect",
-                                  tags=self.base_tags)
+        self.metrics.increment("client.socket.disconnect", tags=self.base_tags)
        elapsed = (ms_time() - self.ps.connected_at) / 1000.0
-        self.ps.metrics.timing("client.socket.lifespan", duration=elapsed,
-                               tags=self.base_tags)
+        self.metrics.timing("client.socket.lifespan", duration=elapsed,
+                            tags=self.base_tags)
        self.ps.stats.connection_time = int(elapsed)

        # Cleanup our client entry
@@ -661,7 +665,7 @@ def _save_webpush_notif(self, notif):
    def _save_simple_notif(self, channel_id, version):
        """Save a simplepush notification"""
        return deferToThread(
-            self.ap_settings.storage.save_notification,
+            self.db.storage.save_notification,
            uaid=self.ps.uaid,
            chid=channel_id,
            version=version,
@@ -672,7 +676,7 @@ def _lookup_node(self, results):
        connected"""
        # Locate the node that has this client connected
        d = deferToThread(
-            self.ap_settings.router.get_uaid,
+            self.db.router.get_uaid,
            self.ps.uaid
        )
        d.addCallback(self._notify_node)
@@ -684,15 +688,15 @@ def _trap_uaid_not_found(self, fail):
        # type: (failure.Failure) -> None
        """Traps UAID not found error"""
        fail.trap(ItemNotFound)
-        self.ps.metrics.increment("client.lookup_uaid_failure",
-                                  tags=self.base_tags)
+        self.metrics.increment("client.lookup_uaid_failure",
+                               tags=self.base_tags)

    def _notify_node(self, result):
        """Checks the result of lookup node to send the notify if the client
        is connected elsewhere now"""
        if not result:
-            self.ps.metrics.increment("client.notify_uaid_failure",
-                                      tags=self.base_tags)
+            self.metrics.increment("client.notify_uaid_failure",
+                                   tags=self.base_tags)
            return

        node_id = result.get("node_id")
@@ -840,7 +844,7 @@ def _register_user(self, existing_user=True):
        if self.ps.wake_data:
            user_item["wake_data"] = self.ps.wake_data

-        return self.ap_settings.router.register_user(user_item)
+        return self.db.router.register_user(user_item)

    def _verify_user_record(self):
        """Verify a user record is valid
@@ -852,7 +856,7 @@ def _verify_user_record(self):

        """
        try:
-            record = self.ap_settings.router.get_uaid(self.ps.uaid)
+            record = self.db.router.get_uaid(self.ps.uaid)
        except ItemNotFound:
            return None

@@ -862,18 +866,18 @@ def _verify_user_record(self):
            self.log.info(format="Dropping User", code=104,
                          uaid_hash=self.ps.uaid_hash,
                          uaid_record=dump_uaid(record))
-            self.force_retry(self.ap_settings.router.drop_user, self.ps.uaid)
+            self.force_retry(self.db.router.drop_user, self.ps.uaid)
            return None

        # Validate webpush records
        if self.ps.use_webpush:
            # Current month must exist and be a valid prior month
            if ("current_month" not in record) or record["current_month"] \
-                    not in self.ps.settings.message_tables:
+                    not in self.db.message_tables:
                self.log.info(format="Dropping User", code=105,
                              uaid_hash=self.ps.uaid_hash,
                              uaid_record=dump_uaid(record))
-                self.force_retry(self.ap_settings.router.drop_user,
+                self.force_retry(self.db.router.drop_user,
                                 self.ps.uaid)
                return None

@@ -963,7 +967,7 @@ def finish_hello(self, previous):
        self.sendJSON(msg)
        self.log.info(format="hello", uaid_hash=self.ps.uaid_hash,
                      **self.ps.raw_agent)
-        self.ps.metrics.increment("updates.client.hello", tags=self.base_tags)
+        self.metrics.increment("updates.client.hello", tags=self.base_tags)
        self.process_notifications()

    def process_notifications(self):
@@ -996,7 +1000,7 @@ def process_notifications(self):
            d = self.deferToThread(self.webpush_fetch())
        else:
            d = self.deferToThread(
-                self.ap_settings.storage.fetch_notifications, self.ps.uaid)
+                self.db.storage.fetch_notifications, self.ps.uaid)
        d.addCallback(self.finish_notifications)
        d.addErrback(self.error_notification_overload)
        d.addErrback(self.trap_cancel)
@@ -1031,7 +1035,7 @@ def error_notification_overload(self, fail):
    def error_message_overload(self, fail):
        """errBack for handling excessive messages per UAID"""
        fail.trap(MessageOverloadException)
-        self.force_retry(self.ap_settings.router.drop_user, self.ps.uaid)
+        self.force_retry(self.db.router.drop_user, self.ps.uaid)
        self.sendClose()

    def finish_notifications(self, notifs):
@@ -1075,7 +1079,7 @@ def finish_notifications(self, notifs):
            d.addErrback(self.trap_cancel)
        elif self.ps.reset_uaid:
            # Told to reset the user?
-            self.force_retry(self.ap_settings.router.drop_user, self.ps.uaid)
+            self.force_retry(self.db.router.drop_user, self.ps.uaid)
            self.sendClose()

    def finish_webpush_notifications(self, result):
@@ -1110,7 +1114,7 @@ def finish_webpush_notifications(self, result):
            # Told to reset the user?
            if self.ps.reset_uaid:
                self.force_retry(
-                    self.ap_settings.router.drop_user, self.ps.uaid)
+                    self.db.router.drop_user, self.ps.uaid)
                self.sendClose()

            # Not told to check for notifications, do we need to now rotate
@@ -1138,11 +1142,11 @@ def finish_webpush_notifications(self, result):
            msg = notif.websocket_format()
            messages_sent = True
            self.sent_notification_count += 1
-            if self.sent_notification_count > self.ps.msg_limit:
+            if self.sent_notification_count > self.ap_settings.msg_limit:
                raise MessageOverloadException()
            if notif.topic:
-                self.ps.metrics.increment("updates.notification.topic",
-                                          tags=self.base_tags)
+                self.metrics.increment("updates.notification.topic",
+                                       tags=self.base_tags)
            self.sendJSON(msg)

        # Did we send any messages?
@@ -1183,19 +1187,19 @@ def _monthly_transition(self):
        _, channels = self.ps.message.all_channels(self.ps.uaid)

        # Get the current message month
-        cur_month = self.ap_settings.current_msg_month
+        cur_month = self.db.current_msg_month
        if channels:
            # Save the current channels into this months message table
-            msg_table = self.ap_settings.message_tables[cur_month]
+            msg_table = self.db.message_tables[cur_month]
            msg_table.save_channels(self.ps.uaid, channels)

        # Finally, update the route message month
-        self.ap_settings.router.update_message_month(self.ps.uaid, cur_month)
+        self.db.router.update_message_month(self.ps.uaid, cur_month)

    def _finish_monthly_transition(self, result):
        """Mark the client as successfully transitioned and resume"""
        # Update the current month now that we've moved forward a month
-        self.ps.message_month = self.ap_settings.current_msg_month
+        self.ps.message_month = self.db.current_msg_month
        self.ps.rotate_message_table = False
        self.transport.resumeProducing()
@@ -1215,7 +1219,7 @@ def error_monthly_rotation_overload(self, fail):
    def _send_ping(self):
        """Helper for ping sending that tracks when the ping was sent"""
        self.ps.last_ping = time.time()
-        self.ps.metrics.increment("updates.client.ping", tags=self.base_tags)
+        self.metrics.increment("updates.client.ping", tags=self.base_tags)
        return self.sendMessage("{}", False)

    def process_ping(self):
@@ -1291,8 +1295,7 @@ def send_register_finish(self, result, endpoint, chid):
            "status": 200
        }
        self.sendJSON(msg)
-        self.ps.metrics.increment("updates.client.register",
-                                  tags=self.base_tags)
+        self.metrics.increment("updates.client.register", tags=self.base_tags)
        self.ps.stats.registers += 1
        self.log.info(format="Register", channel_id=chid,
                      endpoint=endpoint,
@@ -1310,8 +1313,8 @@ def process_unregister(self, data):
        except ValueError:
            return self.bad_message("unregister", "Invalid ChannelID")

-        self.ps.metrics.increment("updates.client.unregister",
-                                  tags=self.base_tags)
+        self.metrics.increment("updates.client.unregister",
+                               tags=self.base_tags)
        self.ps.stats.unregisters += 1
        event = dict(format="Unregister", channel_id=chid,
                     uaid_hash=self.ps.uaid_hash,
@@ -1335,7 +1338,7 @@ def process_unregister(self, data):
                             self.ps.uaid, chid)
        else:
            # Delete any record from storage, we don't wait for this
-            self.force_retry(self.ap_settings.storage.delete_notification,
+            self.force_retry(self.db.storage.delete_notification,
                             self.ps.uaid, chid)

        data["status"] = 200
@@ -1463,7 +1466,7 @@ def _handle_simple_ack(self, chid, version, code):
                del self.ps.updates_sent[chid]
        else:
            return
-        return self.force_retry(self.ap_settings.storage.delete_notification,
+        return self.force_retry(self.db.storage.delete_notification,
                                self.ps.uaid, chid, version)

    def process_ack(self, data):
@@ -1473,7 +1476,7 @@ def process_ack(self, data):
        if not updates or not isinstance(updates, list):
            return

-        self.ps.metrics.increment("updates.client.ack", tags=self.base_tags)
+        self.metrics.increment("updates.client.ack", tags=self.base_tags)
        defers = filter(None, map(self.ack_update, updates))

        if defers:
@@ -1551,8 +1554,8 @@ def send_notification(self, update):
                                            update)
            self.ps.direct_updates[chid].append(notif)
            if notif.topic:
-                self.ps.metrics.increment("updates.notification.topic",
-                                          tags=self.base_tags)
+                self.metrics.increment("updates.notification.topic",
+                                       tags=self.base_tags)
            self.sendJSON(notif.websocket_format())
        else:
            self.ps.direct_updates[chid] = version
@@ -1565,10 +1568,11 @@ class PushServerFactory(WebSocketServerFactory):

    protocol = PushServerProtocol

-    def __init__(self, ap_settings):
-        # type: (AutopushSettings) -> None
+    def __init__(self, ap_settings, db):
+        # type: (AutopushSettings, DatabaseManager) -> None
        WebSocketServerFactory.__init__(self, ap_settings.ws_url)
        self.ap_settings = ap_settings
+        self.db = db
        self.setProtocolOptions(
            webStatus=False,
            openHandshakeTimeout=5,
@@ -1596,20 +1600,20 @@ def put(self, uaid):
        client = settings.clients.get(uaid)
        if not client:
            self.set_status(404, reason=None)
-            settings.metrics.increment("updates.router.disconnected")
+            self.metrics.increment("updates.router.disconnected")
            self.write("Client not connected.")
            return

        if client.paused:
            self.set_status(503, reason=None)
-            settings.metrics.increment("updates.router.busy")
+            self.metrics.increment("updates.router.busy")
            self.write("Client busy.")
            return

        update = json.loads(self.request.body)
        client.send_notification(update)
-        settings.metrics.increment("updates.router.received")
+        self.metrics.increment("updates.router.received")
        self.write("Client accepted for delivery")

@@ -1623,10 +1627,9 @@ def put(self, uaid, *args):

        """
        client = self.ap_settings.clients.get(uaid)
-        settings = self.ap_settings
        if not client:
            self.set_status(404, reason=None)
-            settings.metrics.increment("updates.notification.disconnected")
+            self.metrics.increment("updates.notification.disconnected")
            self.write("Client not connected.")
            return

@@ -1634,13 +1637,13 @@ def put(self, uaid, *args):
            # Client already busy waiting for stuff, flag for check
            client._check_notifications = True
            self.set_status(202)
-            settings.metrics.increment("updates.notification.flagged")
+            self.metrics.increment("updates.notification.flagged")
            self.write("Flagged for Notification check")
            return

        # Client is online and idle, start a notification check
        client.process_notifications()
-        settings.metrics.increment("updates.notification.checking")
+        self.metrics.increment("updates.notification.checking")
        self.write("Notification check started")

    def delete(self, uaid, connected_at):