diff --git a/autopush/haproxy.py b/autopush/haproxy.py new file mode 100644 index 00000000..d511b723 --- /dev/null +++ b/autopush/haproxy.py @@ -0,0 +1,35 @@ +from twisted.internet import defer +from twisted.internet.interfaces import IStreamServerEndpoint +from twisted.protocols.haproxy._wrapper import HAProxyWrappingFactory +from twisted.protocols.tls import TLSMemoryBIOFactory +from zope.interface import implementer + + +@implementer(IStreamServerEndpoint) +class HAProxyServerEndpoint(object): + """A HAProxy endpoint, optionally handling TLS""" + + wrapper_factory = HAProxyWrappingFactory + + def __init__(self, reactor, port, ssl_cf=None, **kwargs): + self._reactor = reactor + self._port = port + self._ssl_cf = ssl_cf + self._kwargs = kwargs + + def listen(self, factory): + """Implement IStreamServerEndpoint.listen to listen on TCP. + + Optionally configuring TLS behind the HAProxy protocol. + + """ + if self._ssl_cf: + factory = TLSMemoryBIOFactory(self._ssl_cf, False, factory) + proxyf = self.wrapper_factory(factory) + return defer.execute(self._listen, self._port, proxyf, **self._kwargs) + + def _listen(self, *args, **kwargs): + port = self._reactor.listenTCP(*args, **kwargs) + if self._ssl_cf: + port._type = 'TLS' + return port diff --git a/autopush/main.py b/autopush/main.py index 347b23b9..dd14d4c2 100644 --- a/autopush/main.py +++ b/autopush/main.py @@ -610,21 +610,24 @@ def endpoint_main(sysargs=None, use_files=True): settings.metrics.start() - def create_endpoint(port): - if not args.ssl_key: - return TCP4ServerEndpoint(reactor, port) + if args.ssl_key: ssl_cf = AutopushSSLContextFactory( args.ssl_key, args.ssl_cert, dh_file=args.ssl_dh_param, require_peer_certs=settings.enable_tls_auth) - return SSL4ServerEndpoint(reactor, port, ssl_cf) - - endpoint = create_endpoint(args.port) + endpoint = SSL4ServerEndpoint(reactor, args.port, ssl_cf) + else: + ssl_cf = None + endpoint = TCP4ServerEndpoint(reactor, args.port) endpoint.listen(site) + if args.proxy_protocol_port: - from twisted.protocols.haproxy import proxyEndpoint - pendpoint = proxyEndpoint(create_endpoint(args.proxy_protocol_port)) + from autopush.haproxy import HAProxyServerEndpoint + pendpoint = HAProxyServerEndpoint( + reactor, + args.proxy_protocol_port, + ssl_cf) pendpoint.listen(site) reactor.suggestThreadPoolSize(50) diff --git a/autopush/tests/test_integration.py b/autopush/tests/test_integration.py index bb05ccf6..7a7874c2 100644 --- a/autopush/tests/test_integration.py +++ b/autopush/tests/test_integration.py @@ -358,6 +358,7 @@ class IntegrationBase(unittest.TestCase): track_objects = True track_objects_excludes = [AutopushSettings, PushServerFactory] proxy_protocol_port = None + endpoint_scheme = 'http' def setUp(self): import cyclone.web @@ -383,12 +384,10 @@ def setUp(self): storage_table = os.environ.get("STORAGE_TABLE", "storage_int_test") message_table = os.environ.get("MESSAGE_TABLE", "message_int_test") - client_certs = self.make_client_certs() - is_https = client_certs is not None - self.endpoint_port = 9020 self.router_port = 9030 self.memusage_port = 9040 + client_certs = self.make_client_certs() settings = AutopushSettings( hostname="localhost", statsd_host=None, @@ -398,7 +397,7 @@ def setUp(self): storage_tablename=storage_table, message_tablename=message_table, client_certs=client_certs, - endpoint_scheme='https' if is_https else 'http', + endpoint_scheme=self.endpoint_scheme, ) # Websocket server @@ -445,14 +444,11 @@ def setUp(self): self._settings = settings self.endpoints = [] - ep = self._create_endpoint(self.endpoint_port, is_https) + ep = self.create_endpoint(self.endpoint_port) ep.listen(site).addCallback(self._endpoint_listening) if self.proxy_protocol_port: - from twisted.protocols.haproxy import proxyEndpoint - ep = proxyEndpoint( - self._create_endpoint(self.proxy_protocol_port, is_https) - ) + ep = self.create_proxy_endpoint(self.proxy_protocol_port) ep.listen(site).addCallback(self._endpoint_listening) self.memusage_site = create_memusage_site( @@ -460,10 +456,12 @@ def setUp(self): self.memusage_port, False) - def _create_endpoint(self, port, is_https): - if not is_https: - return TCP4ServerEndpoint(reactor, port) - return SSL4ServerEndpoint(reactor, port, self.endpoint_SSLCF()) + def create_endpoint(self, port): + return TCP4ServerEndpoint(reactor, port) + + def create_proxy_endpoint(self, port): + from autopush.haproxy import HAProxyServerEndpoint + return HAProxyServerEndpoint(reactor, self.proxy_protocol_port) def _endpoint_listening(self, port): self.endpoints.append(port) @@ -508,6 +506,70 @@ def legacy_endpoint(self): self._settings._notification_legacy = False +class SSLEndpointMixin(object): + + endpoint_scheme = 'https' + certs = os.path.join(os.path.dirname(__file__), "certs") + servercert = os.path.join(certs, "server.pem") + + def create_endpoint(self, port): + return SSL4ServerEndpoint(reactor, port, self.endpoint_SSLCF()) + + def endpoint_SSLCF(self): + """Return an SSLContextFactory for the endpoint. + + Configured with the self-signed test server.pem. server.pem is + additionally the signer of the client certs in the same dir. + + """ + from autopush.ssl import AutopushSSLContextFactory + return AutopushSSLContextFactory( + self.servercert, + self.servercert, + require_peer_certs=self._settings.enable_tls_auth) + + def client_SSLCF(self, certfile): + """Return an IPolicyForHTTPS for verifiying tests' server cert. + + Optionally configures a client cert. + + """ + from twisted.internet.ssl import ( + Certificate, PrivateCertificate, optionsForClientTLS) + from twisted.web.iweb import IPolicyForHTTPS + from zope.interface import implementer + + with open(self.servercert) as fp: + servercert = Certificate.loadPEM(fp.read()) + if certfile: + with open(self.unauth_client) as fp: + unauth_client = PrivateCertificate.loadPEM(fp.read()) + else: + unauth_client = None + + @implementer(IPolicyForHTTPS) + class UnauthClientPolicyForHTTPS(object): + def creatorForNetloc(self, hostname, port): + return optionsForClientTLS( + hostname.decode('ascii'), + trustRoot=servercert, + clientCertificate=unauth_client) + return UnauthClientPolicyForHTTPS() + + def _create_context(self, certfile): + """Return a client SSLContext + + Optionally configures a client cert. + + """ + import ssl + context = ssl.create_default_context() + if certfile: + context.load_cert_chain(certfile) + context.load_verify_locations(self.servercert) + return context + + class TestSimple(IntegrationBase): @inlineCallbacks def test_delivery_while_disconnected(self): @@ -1573,14 +1635,13 @@ def test_with_key(self): yield self.shut_down(client) -class TestClientCerts(IntegrationBase): +class TestClientCerts(SSLEndpointMixin, IntegrationBase): def setUp(self): - self.certs = certs = os.path.join(os.path.dirname(__file__), "certs") - self.servercert = os.path.join(certs, "server.pem") + certs = self.certs self.auth_client = os.path.join(certs, "client1.pem") self.unauth_client = os.path.join(certs, "client2.pem") - with open(os.path.join(self.certs, "client1_sha256.txt")) as fp: + with open(os.path.join(certs, "client1_sha256.txt")) as fp: client1_sha256 = fp.read().strip() self._client_certs = {client1_sha256: 'partner1'} IntegrationBase.setUp(self) @@ -1588,51 +1649,6 @@ def setUp(self): def make_client_certs(self): return self._client_certs - def endpoint_SSLCF(self): - """Return an SSLContextFactory for the endpoint. - - Configured with the self-signed test server.pem. server.pem is - additionally the signer of the client certs. - - """ - from autopush.ssl import AutopushSSLContextFactory - return AutopushSSLContextFactory( - self.servercert, - self.servercert, - require_peer_certs=self._settings.enable_tls_auth) - - def _create_unauth_SSLCF(self, certfile): - """Return an IPolicyForHTTPS for the unauthorized client""" - from twisted.internet.ssl import ( - Certificate, PrivateCertificate, optionsForClientTLS) - from twisted.web.iweb import IPolicyForHTTPS - from zope.interface import implementer - - with open(self.servercert) as fp: - servercert = Certificate.loadPEM(fp.read()) - if certfile: - with open(self.unauth_client) as fp: - unauth_client = PrivateCertificate.loadPEM(fp.read()) - else: - unauth_client = None - - @implementer(IPolicyForHTTPS) - class UnauthClientPolicyForHTTPS(object): - def creatorForNetloc(self, hostname, port): - return optionsForClientTLS(hostname.decode('ascii'), - trustRoot=servercert, - clientCertificate=unauth_client) - return UnauthClientPolicyForHTTPS() - - def _create_context(self, certfile): - """Return a client SSLContext""" - import ssl - context = ssl.create_default_context() - if certfile: - context.load_cert_chain(certfile) - context.load_verify_locations(self.servercert) - return context - @inlineCallbacks def test_client_cert_simple(self): client = yield self.quick_register( @@ -1696,7 +1712,7 @@ def _test_unauth(self, certfile): response, body = yield _agent( 'DELETE', "https://localhost:9020/m/foo", - contextFactory=self._create_unauth_SSLCF(certfile)) + contextFactory=self.client_SSLCF(certfile)) eq_(response.code, 401) wwwauth = response.headers.getRawHeaders('www-authenticate') eq_(wwwauth, ['Transport mode="tls-client-certificate"']) @@ -1714,7 +1730,7 @@ def _test_log_check_skips_auth(self, certfile): response, body = yield _agent( 'GET', "https://localhost:9020/v1/err", - contextFactory=self._create_unauth_SSLCF(certfile)) + contextFactory=self.client_SSLCF(certfile)) eq_(response.code, 418) payload = json.loads(body) eq_(payload['error'], "Test Error") @@ -1732,7 +1748,7 @@ def _test_status_skips_auth(self, certfile): response, body = yield _agent( 'GET', "https://localhost:9020/status", - contextFactory=self._create_unauth_SSLCF(certfile)) + contextFactory=self.client_SSLCF(certfile)) eq_(response.code, 200) payload = json.loads(body) eq_(payload, dict(status="OK", version=__version__)) @@ -1750,7 +1766,7 @@ def _test_health_skips_auth(self, certfile): response, body = yield _agent( 'GET', "https://localhost:9020/health", - contextFactory=self._create_unauth_SSLCF(certfile)) + contextFactory=self.client_SSLCF(certfile)) eq_(response.code, 200) payload = json.loads(body) eq_(payload['version'], __version__) @@ -2045,6 +2061,54 @@ def test_no_proxy_protocol(self): eq_(payload['error'], "Test Error") +class TestProxyProtocolSSL(SSLEndpointMixin, IntegrationBase): + proxy_protocol_port = 9021 + + def create_proxy_endpoint(self, port): + from autopush.haproxy import HAProxyServerEndpoint + return HAProxyServerEndpoint( + reactor, + self.proxy_protocol_port, + self.endpoint_SSLCF()) + + @inlineCallbacks + def test_proxy_protocol_ssl(self): + ip = '198.51.100.22' + + def proxy_request(): + # like TestProxyProtocol.test_proxy_protocol, we prepend + # the proxy proto. line before the payload (which is + # encrypted in this case). HACK: sneak around httplib's + # wrapped ssl sock by hooking into SSLContext.wrap_socket + proto_line = 'PROXY TCP4 {} 203.0.113.7 35646 80\r\n'.format(ip) + + class SSLContextWrapper(object): + def __init__(self, context): + self.context = context + + def wrap_socket(self, sock, *args, **kwargs): + # send proto_line over the raw, unencrypted sock + sock.send(proto_line) + # now do the handshake/encrypt sock + return self.context.wrap_socket(sock, *args, **kwargs) + + http = httplib.HTTPSConnection( + "localhost:{}".format(self.proxy_protocol_port), + context=SSLContextWrapper(self._create_context(None))) + try: + http.request('GET', '/v1/err') + response = http.getresponse() + return response, response.read() + finally: + http.close() + + response, body = yield deferToThread(proxy_request) + eq_(response.status, 418) + payload = json.loads(body) + eq_(payload['error'], "Test Error") + ok_(self.logs.logged_ci(lambda ci: ci.get('remote_ip') == ip)) + + class TestMemUsage(IntegrationBase): @inlineCallbacks diff --git a/autopush/websocket.py b/autopush/websocket.py index 241325e6..266d51f4 100644 --- a/autopush/websocket.py +++ b/autopush/websocket.py @@ -70,7 +70,7 @@ Optional, Tuple, ) -from zope.interface import implements +from zope.interface import implementer from autopush import __version__ from autopush.base import BaseHandler @@ -136,8 +136,8 @@ def wrapper(self, *args, **kwargs): return wrapper +@implementer(IProducer) class PushState(object): - implements(IProducer) __slots__ = [ '_callbacks',