Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
Merge pull request #945 from mozilla-services/refactor/632
Browse files Browse the repository at this point in the history
refactor: split clients and agent off settings
  • Loading branch information
jrconlin authored Jun 29, 2017
2 parents ef0b74d + 70129bc commit 2594f9e
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 103 deletions.
69 changes: 62 additions & 7 deletions autopush/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
)

import cyclone.web
from twisted.internet import reactor
from twisted.web.client import (
_HTTP11ClientFactory,
Agent,
HTTPConnectionPool,
)

from autopush.base import BaseHandler
from autopush.db import DatabaseManager
Expand Down Expand Up @@ -59,14 +65,12 @@ class BaseHTTPFactory(cyclone.web.Application):
def __init__(self,
ap_settings, # type: AutopushSettings
db, # type: DatabaseManager
routers, # type: Dict[str, IRouter]
handlers=None, # type: APHandlers
log_function=skip_request_logging, # type: CycloneLogger
**kwargs):
# type: (...) -> None
self.ap_settings = ap_settings
self.db = db
self.routers = routers
self.noisy = ap_settings.debug

cyclone.web.Application.__init__(
Expand All @@ -91,7 +95,6 @@ def for_handler(cls,
handler_cls, # Type[BaseHTTPFactory]
ap_settings, # type: AutopushSettings
db=None, # type: Optional[DatabaseManager]
routers=None, # type: Optional[Dict[str, IRouter]]
**kwargs):
# type: (...) -> BaseHTTPFactory
"""Create a cyclone app around a specific handler_cls for tests.
Expand All @@ -109,18 +112,21 @@ def for_handler(cls,
if handler is handler_cls:
if db is None:
db = DatabaseManager.from_settings(ap_settings)
if routers is None:
routers = routers_from_settings(ap_settings, db)
return cls(
return cls._for_handler(
ap_settings,
db=db,
routers=routers,
handlers=[(pattern, handler)],
**kwargs
)
raise ValueError("{!r} not in ap_handlers".format(
handler_cls)) # pragma: nocover

@classmethod
def _for_handler(cls, **kwargs):
# type: (**Any) -> BaseHTTPFactory
"""Create an instance w/ default kwargs for for_handler"""
raise NotImplementedError # pragma: nocover


class EndpointHTTPFactory(BaseHTTPFactory):

Expand All @@ -146,6 +152,15 @@ class EndpointHTTPFactory(BaseHTTPFactory):

protocol = LimitedHTTPConnection

def __init__(self,
ap_settings, # type: AutopushSettings
db, # type: DatabaseManager
routers, # type: Dict[str, IRouter]
**kwargs):
# type: (...) -> None
BaseHTTPFactory.__init__(self, ap_settings, db=db, **kwargs)
self.routers = routers

def ssl_cf(self):
# type: () -> Optional[AutopushSSLContextFactory]
"""Build our SSL Factory (if configured).
Expand All @@ -164,6 +179,16 @@ def ssl_cf(self):
require_peer_certs=settings.enable_tls_auth
)

@classmethod
def _for_handler(cls, ap_settings, db, routers=None, **kwargs):
if routers is None:
routers = routers_from_settings(
ap_settings,
db=db,
agent=agent_from_settings(ap_settings)
)
return cls(ap_settings, db=db, routers=routers, **kwargs)


class InternalRouterHTTPFactory(BaseHTTPFactory):

Expand All @@ -172,6 +197,15 @@ class InternalRouterHTTPFactory(BaseHTTPFactory):
(r"/notif/([^\/]+)(?:/(\d+))?", NotificationHandler),
)

def __init__(self,
ap_settings, # type: AutopushSettings
db, # type: DatabaseManager
clients, # type: Dict[str, PushServerProtocol]
**kwargs):
# type: (...) -> None
BaseHTTPFactory.__init__(self, ap_settings, db, **kwargs)
self.clients = clients

@property
def _hostname(self):
return self.ap_settings.router_hostname
Expand All @@ -193,9 +227,30 @@ def ssl_cf(self):
dh_file=settings.ssl_dh_param
)

@classmethod
def _for_handler(cls, ap_settings, db, clients=None, **kwargs):
if clients is None:
clients = {}
return cls(ap_settings, db=db, clients=clients, **kwargs)


class MemUsageHTTPFactory(BaseHTTPFactory):

ap_handlers = (
(r"^/_memusage", MemUsageHandler),
)


class QuietClientFactory(_HTTP11ClientFactory):
"""Silence the start/stop factory messages."""
noisy = False


def agent_from_settings(settings):
# type: (AutopushSettings) -> Agent
"""Create a twisted.web.client Agent from settings"""
# Use a persistent connection pool for HTTP requests.
pool = HTTPConnectionPool(reactor)
if not settings.debug:
pool._factory = QuietClientFactory
return Agent(reactor, connectTimeout=settings.connect_timeout, pool=pool)
29 changes: 21 additions & 8 deletions autopush/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from twisted.application.service import MultiService
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks
from twisted.internet.protocol import ServerFactory # noqa
from twisted.logger import Logger
from typing import ( # noqa
Expand All @@ -21,7 +22,8 @@
from autopush.http import (
InternalRouterHTTPFactory,
EndpointHTTPFactory,
MemUsageHTTPFactory
MemUsageHTTPFactory,
agent_from_settings
)
import autopush.utils as utils
import autopush.logging as logging
Expand All @@ -35,7 +37,6 @@
from autopush.websocket import (
ConnectionWSSite,
PushServerFactory,
periodic_reporter,
)

log = Logger()
Expand All @@ -60,6 +61,7 @@ def __init__(self, settings):
super(AutopushMultiService, self).__init__()
self.settings = settings
self.db = DatabaseManager.from_settings(settings)
self.agent = agent_from_settings(settings)

@staticmethod
def parse_args(config_files, args):
Expand Down Expand Up @@ -87,7 +89,7 @@ def add_timer(self, *args, **kwargs):

def add_memusage(self):
"""Add the memusage Service"""
factory = MemUsageHTTPFactory(self.settings, None, None)
factory = MemUsageHTTPFactory(self.settings, None)
self.addService(
TCPServer(self.settings.memusage_port, factory, reactor=reactor))

Expand All @@ -97,6 +99,11 @@ def run(self):
self.startService()
reactor.run()

@inlineCallbacks
def stopService(self):
yield self.agent._pool.closeCachedConnections()
yield super(AutopushMultiService, self).stopService()

@classmethod
def _from_argparse(cls, ns, **kwargs):
# type: (Namespace, **Any) -> AutopushMultiService
Expand Down Expand Up @@ -157,7 +164,8 @@ class EndpointApplication(AutopushMultiService):

def __init__(self, *args, **kwargs):
super(EndpointApplication, self).__init__(*args, **kwargs)
self.routers = routers_from_settings(self.settings, self.db)
self.routers = routers_from_settings(self.settings, self.db,
self.agent)

def setup(self, rotate_tables=True):
self.db.setup(self.settings.preflight_uaid)
Expand Down Expand Up @@ -220,6 +228,10 @@ class ConnectionApplication(AutopushMultiService):
websocket_factory = PushServerFactory
websocket_site_factory = ConnectionWSSite

def __init__(self, *args, **kwargs):
super(ConnectionApplication, self).__init__(*args, **kwargs)
self.clients = {}

def setup(self, rotate_tables=True):
self.db.setup(self.settings.preflight_uaid)

Expand All @@ -235,19 +247,20 @@ def setup(self, rotate_tables=True):

def add_internal_router(self):
"""Start the internal HTTP notification router"""
factory = self.internal_router_factory(self.settings, self.db, None)
factory = self.internal_router_factory(
self.settings, self.db, self.clients)
factory.add_health_handlers()
self.add_maybe_ssl(self.settings.router_port, factory,
factory.ssl_cf())

def add_websocket(self):
"""Start the public WebSocket server"""
settings = self.settings
ws_factory = self.websocket_factory(settings, self.db)
ws_factory = self.websocket_factory(settings, self.db, self.agent,
self.clients)
site_factory = self.websocket_site_factory(settings, ws_factory)
self.add_maybe_ssl(settings.port, site_factory, site_factory.ssl_cf())
self.add_timer(1.0, periodic_reporter, settings, self.db.metrics,
ws_factory)
self.add_timer(1.0, ws_factory.periodic_reporter, self.db.metrics)

@classmethod
def from_argparse(cls, ns):
Expand Down
10 changes: 6 additions & 4 deletions autopush/router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"""
from typing import Dict # noqa

from twisted.web.client import Agent # noqa

from autopush.db import DatabaseManager # noqa
from autopush.router.apnsrouter import APNSRouter
from autopush.router.gcm import GCMRouter
Expand All @@ -19,14 +21,14 @@
"WebPushRouter"]


def routers_from_settings(settings, db):
# type: (AutopushSettings, DatabaseManager) -> Dict[str, IRouter]
def routers_from_settings(settings, db, agent):
# type: (AutopushSettings, DatabaseManager, Agent) -> Dict[str, IRouter]
"""Create a dict of IRouters for the given settings"""
router_conf = settings.router_conf
routers = dict(
simplepush=SimpleRouter(
settings, router_conf.get("simplepush"), db),
webpush=WebPushRouter(settings, None, db)
settings, router_conf.get("simplepush"), db, agent),
webpush=WebPushRouter(settings, None, db, agent)
)
if 'apns' in router_conf:
routers["apns"] = APNSRouter(settings, router_conf["apns"], db.metrics)
Expand Down
7 changes: 4 additions & 3 deletions autopush/router/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ class SimpleRouter(object):
"""
log = Logger()

def __init__(self, ap_settings, router_conf, db):
def __init__(self, ap_settings, router_conf, db, agent):
"""Create a new SimpleRouter"""
self.ap_settings = ap_settings
self.conf = router_conf
self.db = db
self.agent = agent
self.waker = None

@property
Expand Down Expand Up @@ -195,7 +196,7 @@ def _send_notification(self, uaid, node_id, notification):
"version": notification.version,
"data": notification.data})
url = node_id + "/push/" + uaid
d = self.ap_settings.agent.request(
d = self.agent.request(
"PUT",
url.encode("utf8"),
bodyProducer=FileBodyProducer(StringIO(payload)),
Expand All @@ -206,7 +207,7 @@ def _send_notification(self, uaid, node_id, notification):
def _send_notification_check(self, uaid, node_id):
"""Send a command to the node to check for notifications"""
url = node_id + "/notif/" + uaid
return self.ap_settings.agent.request(
return self.agent.request(
"PUT",
url.encode("utf8"),
).addCallback(IgnoreBody.ignore)
Expand Down
2 changes: 1 addition & 1 deletion autopush/router/webpush.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _send_notification(self, uaid, node_id, notification):
payload = notification.serialize()
payload["timestamp"] = int(time.time())
url = node_id + "/push/" + uaid
request = self.ap_settings.agent.request(
request = self.agent.request(
"PUT",
url.encode("utf8"),
bodyProducer=FileBodyProducer(StringIO(json.dumps(payload))),
Expand Down
12 changes: 2 additions & 10 deletions autopush/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from cryptography.fernet import Fernet, MultiFernet
from cryptography.hazmat.primitives import constant_time
from twisted.internet import reactor
from twisted.web.client import Agent, HTTPConnectionPool, _HTTP11ClientFactory

from twisted.web.client import _HTTP11ClientFactory

import autopush.db as db
from autopush.exceptions import (
Expand Down Expand Up @@ -98,13 +96,8 @@ def __init__(self,
"""
self.debug = debug
# Use a persistent connection pool for HTTP requests.
pool = HTTPConnectionPool(reactor)
if not debug:
pool._factory = QuietClientFactory

self.agent = Agent(reactor, connectTimeout=connect_timeout,
pool=pool)
self.connect_timeout = connect_timeout

if not crypto_key:
crypto_key = [Fernet.generate_key()]
Expand All @@ -120,7 +113,6 @@ def __init__(self,
self.bear_hash_key = bear_hash_key

self.max_data = max_data
self.clients = {}

# Setup hosts/ports/urls
default_hostname = socket.gethostname()
Expand Down
2 changes: 1 addition & 1 deletion autopush/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def setUp(self):
}
db.create_initial_message_tables()

self.routers = routers = routers_from_settings(settings, db)
self.routers = routers = routers_from_settings(settings, db, Mock())
routers["test"] = Mock(spec=IRouter)
app = EndpointHTTPFactory(settings, db=db, routers=routers)
self.client = Client(app)
Expand Down
9 changes: 3 additions & 6 deletions autopush/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,9 +375,6 @@ def setUp(self):
crypto_key=crypto_key,
**self.conn_kwargs()
)
# Dirty reactor unless we shut down the cached connections
self.addCleanup(ep_settings.agent._pool.closeCachedConnections)
self.addCleanup(conn_settings.agent._pool.closeCachedConnections)

# Endpoint HTTP router
self.ep = ep = EndpointApplication(ep_settings)
Expand Down Expand Up @@ -1409,7 +1406,7 @@ def test_webpush_monthly_rotation(self):
eq_(chan, result["channelID"])

# Check that the client is going to rotate the month
server_client = self.conn.settings.clients[client.uaid]
server_client = self.conn.clients[client.uaid]
eq_(server_client.ps.rotate_message_table, True)

# Acknowledge the notification, which triggers the migration
Expand Down Expand Up @@ -1505,7 +1502,7 @@ def test_webpush_monthly_rotation_prior_record_exists(self):
eq_(chan, result["channelID"])

# Check that the client is going to rotate the month
server_client = self.conn.settings.clients[client.uaid]
server_client = self.conn.clients[client.uaid]
eq_(server_client.ps.rotate_message_table, True)

# Acknowledge the notification, which triggers the migration
Expand Down Expand Up @@ -1573,7 +1570,7 @@ def test_webpush_monthly_rotation_no_channels(self):
yield client.hello()

# Check that the client is going to rotate the month
server_client = self.conn.settings.clients[client.uaid]
server_client = self.conn.clients[client.uaid]
eq_(server_client.ps.rotate_message_table, True)

# Wait up to 2 seconds for the table rotation to occur
Expand Down
Loading

0 comments on commit 2594f9e

Please sign in to comment.