Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: cache SSLContexts created by ssl.wrap_socket
Browse files Browse the repository at this point in the history
avoids unnecessary allocation of them which appears to stress pypy's
GC (w/ their finalization needs)

also fix some type signatures and kill some now unnecessary mock patches

Closes: #1038
  • Loading branch information
pjenvey committed Oct 18, 2017
1 parent 334bf81 commit 6dcbba2
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 13 deletions.
4 changes: 4 additions & 0 deletions autopush/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class AutopushConfig(object):
# Strict-Transport-Security max age (Default 1 year in secs)
sts_max_age = attrib(default=31536000) # type: int

# Don't cache ssl.wrap_socket's SSLContexts
no_sslcontext_cache = attrib(default=False) # type: bool

def __attrs_post_init__(self):
"""Initialize the Settings object"""
# Setup hosts/ports/urls
Expand Down Expand Up @@ -303,6 +306,7 @@ def from_argparse(cls, ns, **kwargs):
connect_timeout=ns.connection_timeout,
memusage_port=ns.memusage_port,
use_cryptography=ns.use_cryptography,
no_sslcontext_cache=ns._no_sslcontext_cache,
router_table=dict(
tablename=ns.router_tablename,
read_throughput=ns.router_read_throughput,
Expand Down
3 changes: 2 additions & 1 deletion autopush/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
Set,
TypeVar,
Tuple,
Union,
)
from twisted.internet.defer import Deferred # noqa
from twisted.internet.defer import inlineCallbacks, returnValue
Expand Down Expand Up @@ -504,7 +505,7 @@ def fetch_messages(
def fetch_timestamp_messages(
self,
uaid, # type: uuid.UUID
timestamp=None, # type: Optional[int or str]
timestamp=None, # type: Optional[Union[int, str]]
limit=10, # type: int
):
# type: (...) -> Tuple[Optional[int], List[WebPushNotification]]
Expand Down
15 changes: 14 additions & 1 deletion autopush/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
from autopush.logging import PushLogger
from autopush.main_argparse import parse_connection, parse_endpoint
from autopush.router import routers_from_config
from autopush.ssl import (
monkey_patch_ssl_wrap_socket,
undo_monkey_patch_ssl_wrap_socket,
)
from autopush.webpush_server import WebPushServer
from autopush.websocket import (
ConnectionWSSite,
Expand Down Expand Up @@ -75,7 +79,8 @@ def parse_args(config_files, args):
def setup(self, rotate_tables=True):
# type: (bool) -> None
"""Initialize the services"""
raise NotImplementedError # pragma: nocover
if not self.conf.no_sslcontext_cache:
monkey_patch_ssl_wrap_socket()

def add_maybe_ssl(self, port, factory, ssl_cf):
# type: (int, ServerFactory, Optional[Any]) -> None
Expand Down Expand Up @@ -106,6 +111,8 @@ def run(self):
def stopService(self):
yield self.agent._pool.closeCachedConnections()
yield super(AutopushMultiService, self).stopService()
if not self.conf.no_sslcontext_cache:
undo_monkey_patch_ssl_wrap_socket()

@classmethod
def _from_argparse(cls, ns, **kwargs):
Expand Down Expand Up @@ -171,6 +178,8 @@ def __init__(self, conf):
self.routers = routers_from_config(conf, self.db, self.agent)

def setup(self, rotate_tables=True):
super(EndpointApplication, self).setup(rotate_tables)

self.db.setup(self.conf.preflight_uaid)

self.add_endpoint()
Expand Down Expand Up @@ -237,6 +246,8 @@ def __init__(self, conf):
self.clients = {} # type: Dict[str, PushServerProtocol]

def setup(self, rotate_tables=True):
super(ConnectionApplication, self).setup(rotate_tables)

self.db.setup(self.conf.preflight_uaid)

self.add_internal_router()
Expand Down Expand Up @@ -309,6 +320,8 @@ def __init__(self, conf):
super(RustConnectionApplication, self).__init__(conf)

def setup(self, rotate_tables=True):
super(RustConnectionApplication, self).setup(rotate_tables)

self.db.setup(self.conf.preflight_uaid)

if self.conf.memusage_port:
Expand Down
4 changes: 4 additions & 0 deletions autopush/main_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def add_shared_args(parser):
help="Max Strict Transport Age in seconds",
type=int, default=31536000,
env_var="STS_MAX_AGE")
parser.add_argument('--_no_sslcontext_cache',
help="Don't cache ssl.wrap_socket's SSLContexts",
action="store_true", default=False,
env_var="_NO_SSLCONTEXT_CACHE")
# No ENV because this is for humans
_add_external_router_args(parser)
_obsolete_args(parser)
Expand Down
81 changes: 78 additions & 3 deletions autopush/ssl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
"""Custom SSL configuration"""
from __future__ import absolute_import
import socket # noqa
import ssl
from typing import ( # noqa
Any,
Dict,
FrozenSet,
Optional,
Tuple,
)

from OpenSSL import SSL
from twisted.internet import ssl
from twisted.internet.ssl import DefaultOpenSSLContextFactory

MOZILLA_INTERMEDIATE_CIPHERS = (
'ECDHE-RSA-AES128-GCM-SHA256:'
Expand Down Expand Up @@ -36,13 +47,13 @@
)


class AutopushSSLContextFactory(ssl.DefaultOpenSSLContextFactory):
class AutopushSSLContextFactory(DefaultOpenSSLContextFactory):
"""A SSL context factory"""

def __init__(self, *args, **kwargs):
self.dh_file = kwargs.pop('dh_file', None)
self.require_peer_certs = kwargs.pop('require_peer_certs', False)
ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **kwargs)
DefaultOpenSSLContextFactory.__init__(self, *args, **kwargs)

def cacheContext(self):
"""Setup the main context factory with custom SSL settings"""
Expand Down Expand Up @@ -77,3 +88,67 @@ def _allow_peer(self, conn, cert, errno, depth, preverify_ok):
# skip verification: we only care about whitelisted signatures
# on file
return True


def monkey_patch_ssl_wrap_socket():
"""Replace ssl.wrap_socket with ssl_wrap_socket_cached"""
ssl.wrap_socket = ssl_wrap_socket_cached


def undo_monkey_patch_ssl_wrap_socket():
"""Undo monkey_patch_ssl_wrap_socket"""
ssl.wrap_socket = _orig_ssl_wrap_socket


_CacheKey = FrozenSet[Tuple[str, Any]]
_sslcontext_cache = {} # type: Dict[_CacheKey, ssl.SSLContext]
_orig_ssl_wrap_socket = ssl.wrap_socket


def ssl_wrap_socket_cached(
sock, # type: socket.socket
keyfile=None, # type: Optional[str]
certfile=None, # type: Optional[str]
server_side=False, # type: bool
cert_reqs=ssl.CERT_NONE, # type: int
ssl_version=ssl.PROTOCOL_TLS, # type: int
ca_certs=None, # type: Optional[str]
do_handshake_on_connect=True, # type: bool
suppress_ragged_eofs=True, # type: bool
ciphers=None # type: Optional[str]
):
# type: (...) -> ssl.SSLSocket
"""ssl.wrap_socket replacement that caches SSLContexts"""
key_kwargs = (
('keyfile', keyfile),
('certfile', certfile),
('cert_reqs', cert_reqs),
('ssl_version', ssl_version),
('ca_certs', ca_certs),
('ciphers', ciphers),
)
key = frozenset(key_kwargs)

context = _sslcontext_cache.get(key)
if context is not None:
return context.wrap_socket(
sock,
server_side=server_side,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs
)

wrapped = _orig_ssl_wrap_socket(
sock,
keyfile=keyfile,
certfile=certfile,
server_side=server_side,
cert_reqs=cert_reqs,
ssl_version=ssl_version,
ca_certs=ca_certs,
do_handshake_on_connect=do_handshake_on_connect,
suppress_ragged_eofs=suppress_ragged_eofs,
ciphers=ciphers
)
_sslcontext_cache[key] = wrapped.context
return wrapped
16 changes: 10 additions & 6 deletions autopush/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class TestArg(AutopushConfig):
disable_simplepush = True
use_cryptography = False
sts_max_age = 1234
_no_sslcontext_cache = False

def setUp(self):
patchers = [
Expand Down Expand Up @@ -318,8 +319,7 @@ def test_memusage(self):
"--memusage_port=8083",
], False)

@patch('hyper.tls', spec=hyper.tls)
def test_client_certs_parse(self, mock):
def test_client_certs_parse(self):
conf = AutopushConfig.from_argparse(self.TestArg)
assert conf.client_certs["1A:"*31 + "F9"] == 'partner1'
assert conf.client_certs["2B:"*31 + "E8"] == 'partner2'
Expand Down Expand Up @@ -377,15 +377,19 @@ def test_gcm_start(self):
"""--senderid_list={"123":{"auth":"abcd"}}""",
], False)

@patch('autopush.router.apns2.HTTP20Connection',
spec=hyper.HTTP20Connection)
@patch('hyper.tls', spec=hyper.tls)
@patch("requests.get")
def test_aws_ami_id(self, request_mock, mt, mc):
def test_aws_ami_id(self, request_mock):
class MockReply:
content = "ami_123"

request_mock.return_value = MockReply
self.TestArg.no_aws = False
conf = AutopushConfig.from_argparse(self.TestArg)
assert conf.ami_id == "ami_123"

def test_no_sslcontext_cache(self):
conf = AutopushConfig.from_argparse(self.TestArg)
assert not conf.no_sslcontext_cache
self.TestArg._no_sslcontext_cache = True
conf = AutopushConfig.from_argparse(self.TestArg)
assert conf.no_sslcontext_cache
36 changes: 36 additions & 0 deletions autopush/tests/test_ssl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import socket
import ssl
from twisted.trial import unittest

from autopush.ssl import (
monkey_patch_ssl_wrap_socket,
ssl_wrap_socket_cached,
undo_monkey_patch_ssl_wrap_socket
)


class SSLContextCacheTestCase(unittest.TestCase):

def setUp(self):
# XXX: test_main doesn't cleanup after itself
undo_monkey_patch_ssl_wrap_socket()

def test_monkey_patch_ssl_wrap_socket(self):
assert ssl.wrap_socket is not ssl_wrap_socket_cached
orig = ssl.wrap_socket
monkey_patch_ssl_wrap_socket()
self.addCleanup(undo_monkey_patch_ssl_wrap_socket)

assert ssl.wrap_socket is ssl_wrap_socket_cached
undo_monkey_patch_ssl_wrap_socket()
assert ssl.wrap_socket is orig

def test_ssl_wrap_socket_cached(self):
monkey_patch_ssl_wrap_socket()
self.addCleanup(undo_monkey_patch_ssl_wrap_socket)

s1 = socket.create_connection(('search.yahoo.com', 443))
s2 = socket.create_connection(('google.com', 443))
ssl1 = ssl.wrap_socket(s1, do_handshake_on_connect=False)
ssl2 = ssl.wrap_socket(s2, do_handshake_on_connect=False)
assert ssl1.context is ssl2.context
10 changes: 8 additions & 2 deletions autopush/webpush_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
)
from boto.dynamodb2.exceptions import ItemNotFound
from boto.exception import JSONResponseError
from typing import Dict, List, Optional # noqa
from typing import ( # noqa
Dict,
List,
Optional,
Tuple,
Union
)
from twisted.logger import Logger

from autopush.db import ( # noqa
Expand Down Expand Up @@ -201,7 +207,7 @@ class StoreMessagesResponse(OutputCommand):
###############################################################################
class WebPushServer(object):
def __init__(self, conf, db, num_threads=10):
# type: (AutopushConfig, DatabaseManager) -> WebPushServer
# type: (AutopushConfig, DatabaseManager, int) -> WebPushServer
self.conf = conf
self.db = db
self.db.setup_tables()
Expand Down

0 comments on commit 6dcbba2

Please sign in to comment.