Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
fix: impl. a haproxy endpoint that actually wraps SSL
Browse files Browse the repository at this point in the history
closes #823
  • Loading branch information
pjenvey committed Feb 17, 2017
1 parent fbbda50 commit f39886d
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 76 deletions.
35 changes: 35 additions & 0 deletions autopush/haproxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from twisted.internet import defer
from twisted.internet.interfaces import IStreamServerEndpoint
from twisted.protocols.haproxy._wrapper import HAProxyWrappingFactory
from twisted.protocols.tls import TLSMemoryBIOFactory
from zope.interface import implementer


@implementer(IStreamServerEndpoint)
class HAProxyServerEndpoint(object):
"""A HAProxy endpoint, optionally handling TLS"""

wrapper_factory = HAProxyWrappingFactory

def __init__(self, reactor, port, ssl_cf=None, **kwargs):
self._reactor = reactor
self._port = port
self._ssl_cf = ssl_cf
self._kwargs = kwargs

def listen(self, factory):
"""Implement IStreamServerEndpoint.listen to listen on TCP.
Optionally configuring TLS behind the HAProxy protocol.
"""
if self._ssl_cf:
factory = TLSMemoryBIOFactory(self._ssl_cf, False, factory)
proxyf = self.wrapper_factory(factory)
return defer.execute(self._listen, self._port, proxyf, **self._kwargs)

def _listen(self, *args, **kwargs):
port = self._reactor.listenTCP(*args, **kwargs)
if self._ssl_cf:
port._type = 'TLS'
return port
19 changes: 11 additions & 8 deletions autopush/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,21 +610,24 @@ def endpoint_main(sysargs=None, use_files=True):

settings.metrics.start()

def create_endpoint(port):
if not args.ssl_key:
return TCP4ServerEndpoint(reactor, port)
if args.ssl_key:
ssl_cf = AutopushSSLContextFactory(
args.ssl_key,
args.ssl_cert,
dh_file=args.ssl_dh_param,
require_peer_certs=settings.enable_tls_auth)
return SSL4ServerEndpoint(reactor, port, ssl_cf)

endpoint = create_endpoint(args.port)
endpoint = SSL4ServerEndpoint(reactor, args.port, ssl_cf)
else:
ssl_cf = None
endpoint = TCP4ServerEndpoint(reactor, args.port)
endpoint.listen(site)

if args.proxy_protocol_port:
from twisted.protocols.haproxy import proxyEndpoint
pendpoint = proxyEndpoint(create_endpoint(args.proxy_protocol_port))
from autopush.haproxy import HAProxyServerEndpoint
pendpoint = HAProxyServerEndpoint(
reactor,
args.proxy_protocol_port,
ssl_cf)
pendpoint.listen(site)

reactor.suggestThreadPoolSize(50)
Expand Down
196 changes: 130 additions & 66 deletions autopush/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ class IntegrationBase(unittest.TestCase):
track_objects = True
track_objects_excludes = [AutopushSettings, PushServerFactory]
proxy_protocol_port = None
endpoint_scheme = 'http'

def setUp(self):
import cyclone.web
Expand All @@ -383,12 +384,10 @@ def setUp(self):
storage_table = os.environ.get("STORAGE_TABLE", "storage_int_test")
message_table = os.environ.get("MESSAGE_TABLE", "message_int_test")

client_certs = self.make_client_certs()
is_https = client_certs is not None

self.endpoint_port = 9020
self.router_port = 9030
self.memusage_port = 9040
client_certs = self.make_client_certs()
settings = AutopushSettings(
hostname="localhost",
statsd_host=None,
Expand All @@ -398,7 +397,7 @@ def setUp(self):
storage_tablename=storage_table,
message_tablename=message_table,
client_certs=client_certs,
endpoint_scheme='https' if is_https else 'http',
endpoint_scheme=self.endpoint_scheme,
)

# Websocket server
Expand Down Expand Up @@ -445,25 +444,24 @@ def setUp(self):
self._settings = settings

self.endpoints = []
ep = self._create_endpoint(self.endpoint_port, is_https)
ep = self.create_endpoint(self.endpoint_port)
ep.listen(site).addCallback(self._endpoint_listening)

if self.proxy_protocol_port:
from twisted.protocols.haproxy import proxyEndpoint
ep = proxyEndpoint(
self._create_endpoint(self.proxy_protocol_port, is_https)
)
ep = self.create_proxy_endpoint(self.proxy_protocol_port)
ep.listen(site).addCallback(self._endpoint_listening)

self.memusage_site = create_memusage_site(
settings,
self.memusage_port,
False)

def _create_endpoint(self, port, is_https):
if not is_https:
return TCP4ServerEndpoint(reactor, port)
return SSL4ServerEndpoint(reactor, port, self.endpoint_SSLCF())
def create_endpoint(self, port):
return TCP4ServerEndpoint(reactor, port)

def create_proxy_endpoint(self, port):
from autopush.haproxy import HAProxyServerEndpoint
return HAProxyServerEndpoint(reactor, self.proxy_protocol_port)

def _endpoint_listening(self, port):
self.endpoints.append(port)
Expand Down Expand Up @@ -508,6 +506,70 @@ def legacy_endpoint(self):
self._settings._notification_legacy = False


class SSLEndpointMixin(object):

endpoint_scheme = 'https'
certs = os.path.join(os.path.dirname(__file__), "certs")
servercert = os.path.join(certs, "server.pem")

def create_endpoint(self, port):
return SSL4ServerEndpoint(reactor, port, self.endpoint_SSLCF())

def endpoint_SSLCF(self):
"""Return an SSLContextFactory for the endpoint.
Configured with the self-signed test server.pem. server.pem is
additionally the signer of the client certs in the same dir.
"""
from autopush.ssl import AutopushSSLContextFactory
return AutopushSSLContextFactory(
self.servercert,
self.servercert,
require_peer_certs=self._settings.enable_tls_auth)

def client_SSLCF(self, certfile):
"""Return an IPolicyForHTTPS for verifiying tests' server cert.
Optionally configures a client cert.
"""
from twisted.internet.ssl import (
Certificate, PrivateCertificate, optionsForClientTLS)
from twisted.web.iweb import IPolicyForHTTPS
from zope.interface import implementer

with open(self.servercert) as fp:
servercert = Certificate.loadPEM(fp.read())
if certfile:
with open(self.unauth_client) as fp:
unauth_client = PrivateCertificate.loadPEM(fp.read())
else:
unauth_client = None

@implementer(IPolicyForHTTPS)
class UnauthClientPolicyForHTTPS(object):
def creatorForNetloc(self, hostname, port):
return optionsForClientTLS(
hostname.decode('ascii'),
trustRoot=servercert,
clientCertificate=unauth_client)
return UnauthClientPolicyForHTTPS()

def _create_context(self, certfile):
"""Return a client SSLContext
Optionally configures a client cert.
"""
import ssl
context = ssl.create_default_context()
if certfile:
context.load_cert_chain(certfile)
context.load_verify_locations(self.servercert)
return context


class TestSimple(IntegrationBase):
@inlineCallbacks
def test_delivery_while_disconnected(self):
Expand Down Expand Up @@ -1573,66 +1635,20 @@ def test_with_key(self):
yield self.shut_down(client)


class TestClientCerts(IntegrationBase):
class TestClientCerts(SSLEndpointMixin, IntegrationBase):

def setUp(self):
self.certs = certs = os.path.join(os.path.dirname(__file__), "certs")
self.servercert = os.path.join(certs, "server.pem")
certs = self.certs
self.auth_client = os.path.join(certs, "client1.pem")
self.unauth_client = os.path.join(certs, "client2.pem")
with open(os.path.join(self.certs, "client1_sha256.txt")) as fp:
with open(os.path.join(certs, "client1_sha256.txt")) as fp:
client1_sha256 = fp.read().strip()
self._client_certs = {client1_sha256: 'partner1'}
IntegrationBase.setUp(self)

def make_client_certs(self):
return self._client_certs

def endpoint_SSLCF(self):
"""Return an SSLContextFactory for the endpoint.
Configured with the self-signed test server.pem. server.pem is
additionally the signer of the client certs.
"""
from autopush.ssl import AutopushSSLContextFactory
return AutopushSSLContextFactory(
self.servercert,
self.servercert,
require_peer_certs=self._settings.enable_tls_auth)

def _create_unauth_SSLCF(self, certfile):
"""Return an IPolicyForHTTPS for the unauthorized client"""
from twisted.internet.ssl import (
Certificate, PrivateCertificate, optionsForClientTLS)
from twisted.web.iweb import IPolicyForHTTPS
from zope.interface import implementer

with open(self.servercert) as fp:
servercert = Certificate.loadPEM(fp.read())
if certfile:
with open(self.unauth_client) as fp:
unauth_client = PrivateCertificate.loadPEM(fp.read())
else:
unauth_client = None

@implementer(IPolicyForHTTPS)
class UnauthClientPolicyForHTTPS(object):
def creatorForNetloc(self, hostname, port):
return optionsForClientTLS(hostname.decode('ascii'),
trustRoot=servercert,
clientCertificate=unauth_client)
return UnauthClientPolicyForHTTPS()

def _create_context(self, certfile):
"""Return a client SSLContext"""
import ssl
context = ssl.create_default_context()
if certfile:
context.load_cert_chain(certfile)
context.load_verify_locations(self.servercert)
return context

@inlineCallbacks
def test_client_cert_simple(self):
client = yield self.quick_register(
Expand Down Expand Up @@ -1696,7 +1712,7 @@ def _test_unauth(self, certfile):
response, body = yield _agent(
'DELETE',
"https://localhost:9020/m/foo",
contextFactory=self._create_unauth_SSLCF(certfile))
contextFactory=self.client_SSLCF(certfile))
eq_(response.code, 401)
wwwauth = response.headers.getRawHeaders('www-authenticate')
eq_(wwwauth, ['Transport mode="tls-client-certificate"'])
Expand All @@ -1714,7 +1730,7 @@ def _test_log_check_skips_auth(self, certfile):
response, body = yield _agent(
'GET',
"https://localhost:9020/v1/err",
contextFactory=self._create_unauth_SSLCF(certfile))
contextFactory=self.client_SSLCF(certfile))
eq_(response.code, 418)
payload = json.loads(body)
eq_(payload['error'], "Test Error")
Expand All @@ -1732,7 +1748,7 @@ def _test_status_skips_auth(self, certfile):
response, body = yield _agent(
'GET',
"https://localhost:9020/status",
contextFactory=self._create_unauth_SSLCF(certfile))
contextFactory=self.client_SSLCF(certfile))
eq_(response.code, 200)
payload = json.loads(body)
eq_(payload, dict(status="OK", version=__version__))
Expand All @@ -1750,7 +1766,7 @@ def _test_health_skips_auth(self, certfile):
response, body = yield _agent(
'GET',
"https://localhost:9020/health",
contextFactory=self._create_unauth_SSLCF(certfile))
contextFactory=self.client_SSLCF(certfile))
eq_(response.code, 200)
payload = json.loads(body)
eq_(payload['version'], __version__)
Expand Down Expand Up @@ -2045,6 +2061,54 @@ def test_no_proxy_protocol(self):
eq_(payload['error'], "Test Error")


class TestProxyProtocolSSL(SSLEndpointMixin, IntegrationBase):
proxy_protocol_port = 9021

def create_proxy_endpoint(self, port):
from autopush.haproxy import HAProxyServerEndpoint
return HAProxyServerEndpoint(
reactor,
self.proxy_protocol_port,
self.endpoint_SSLCF())

@inlineCallbacks
def test_proxy_protocol_ssl(self):
ip = '198.51.100.22'

def proxy_request():
# like TestProxyProtocol.test_proxy_protocol, we prepend
# the proxy proto. line before the payload (which is
# encrypted in this case). HACK: sneak around httplib's
# wrapped ssl sock by hooking into SSLContext.wrap_socket
proto_line = 'PROXY TCP4 {} 203.0.113.7 35646 80\r\n'.format(ip)

class SSLContextWrapper(object):
def __init__(self, context):
self.context = context

def wrap_socket(self, sock, *args, **kwargs):
# send proto_line over the raw, unencrypted sock
sock.send(proto_line)
# now do the handshake/encrypt sock
return self.context.wrap_socket(sock, *args, **kwargs)

http = httplib.HTTPSConnection(
"localhost:{}".format(self.proxy_protocol_port),
context=SSLContextWrapper(self._create_context(None)))
try:
http.request('GET', '/v1/err')
response = http.getresponse()
return response, response.read()
finally:
http.close()

response, body = yield deferToThread(proxy_request)
eq_(response.status, 418)
payload = json.loads(body)
eq_(payload['error'], "Test Error")
ok_(self.logs.logged_ci(lambda ci: ci.get('remote_ip') == ip))


class TestMemUsage(IntegrationBase):

@inlineCallbacks
Expand Down
4 changes: 2 additions & 2 deletions autopush/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
Optional,
Tuple,
)
from zope.interface import implements
from zope.interface import implementer

from autopush import __version__
from autopush.base import BaseHandler
Expand Down Expand Up @@ -136,8 +136,8 @@ def wrapper(self, *args, **kwargs):
return wrapper


@implementer(IProducer)
class PushState(object):
implements(IProducer)

__slots__ = [
'_callbacks',
Expand Down

0 comments on commit f39886d

Please sign in to comment.