diff --git a/CHANGES.txt b/CHANGES.txt index 4a9e019f338..2633ca985d7 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -4,6 +4,9 @@ CHANGES 0.16.0 (XX-XX-XXXX) ------------------- +- Support new `expect_fingerprint` param of TCPConnector to enable verifying + ssl certificates via md5, sha1, or sha256 fingerprint + - Setup uploaded filename if field value is binary and transfer encoding is not specified #349 diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 392309b2166..c1627c53b86 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -8,7 +8,9 @@ import traceback import warnings +from binascii import hexlify, unhexlify from collections import defaultdict +from hashlib import md5, sha1, sha256 from itertools import chain from math import ceil @@ -17,6 +19,7 @@ from .errors import ServerDisconnectedError from .errors import HttpProxyError, ProxyConnectionError from .errors import ClientOSError, ClientTimeoutError +from .errors import FingerprintMismatch from .helpers import BasicAuth @@ -25,6 +28,12 @@ PY_34 = sys.version_info >= (3, 4) PY_343 = sys.version_info >= (3, 4, 3) +HASHFUNC_BY_DIGESTLEN = { + 32: md5, + 40: sha1, + 64: sha256, +} + class Connection(object): @@ -347,13 +356,16 @@ class TCPConnector(BaseConnector): """TCP connector. :param bool verify_ssl: Set to True to check ssl certifications. + :param str expect_fingerprint: Set to the md5, sha1, or sha256 fingerprint + (as a hexadecimal string) of the expected certificate (DER-encoded) + to verify the cert matches. May be interspersed with colons. :param bool resolve: Set to True to do DNS lookup for host name. :param family: socket address family :param args: see :class:`BaseConnector` :param kwargs: see :class:`BaseConnector` """ - def __init__(self, *, verify_ssl=True, + def __init__(self, *, verify_ssl=True, expect_fingerprint=None, resolve=False, family=socket.AF_INET, ssl_context=None, **kwargs): super().__init__(**kwargs) @@ -364,6 +376,17 @@ def __init__(self, *, verify_ssl=True, "verify_ssl=False or specify ssl_context, not both.") self._verify_ssl = verify_ssl + + if expect_fingerprint: + expect_fingerprint = expect_fingerprint.replace(':', '').lower() + digestlen = len(expect_fingerprint) + hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen) + if not hashfunc: + raise ValueError('Fingerprint is of invalid length.') + self._hashfunc = hashfunc + self._fingerprint_bytes = unhexlify(expect_fingerprint) + + self._expect_fingerprint = expect_fingerprint self._ssl_context = ssl_context self._family = family self._resolve = resolve @@ -374,6 +397,11 @@ def verify_ssl(self): """Do check for ssl certifications?""" return self._verify_ssl + @property + def expect_fingerprint(self): + """Expected value of ssl certificate fingerprint, if any.""" + return self._expect_fingerprint + @property def ssl_context(self): """SSLContext instance for https requests. @@ -464,11 +492,25 @@ def _create_connection(self, req): for hinfo in hosts: try: - return (yield from self._loop.create_connection( - self._factory, hinfo['host'], hinfo['port'], + host = hinfo['host'] + port = hinfo['port'] + conn = yield from self._loop.create_connection( + self._factory, host, port, ssl=sslcontext, family=hinfo['family'], proto=hinfo['proto'], flags=hinfo['flags'], - server_hostname=hinfo['hostname'] if sslcontext else None)) + server_hostname=hinfo['hostname'] if sslcontext else None) + if req.ssl and self._expect_fingerprint: + transport = conn[0] + sock = transport.get_extra_info('socket') + # gives DER-encoded cert as a sequence of bytes (or None) + cert = sock.getpeercert(binary_form=True) + got = cert and self._hashfunc(cert).digest() + expected = self._fingerprint_bytes + if expected != got: + got = got and hexlify(got).decode('ascii') + expected = hexlify(expected).decode('ascii') + raise FingerprintMismatch(expected, got, host, port) + return conn except OSError as e: exc = e else: diff --git a/aiohttp/errors.py b/aiohttp/errors.py index 5c148638c1f..b488c963c2c 100644 --- a/aiohttp/errors.py +++ b/aiohttp/errors.py @@ -13,6 +13,7 @@ 'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError', 'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError', 'ClientRequestError', 'ClientResponseError', + 'FingerprintMismatch', 'WSServerHandshakeError', 'WSClientDisconnectedError') @@ -170,3 +171,18 @@ class LineLimitExceededParserError(ParserError): def __init__(self, msg, limit): super().__init__(msg) self.limit = limit + + +class FingerprintMismatch(ClientConnectionError): + """SSL certificate does not match expected fingerprint.""" + + def __init__(self, expected, got, host, port): + self.expected = expected + self.got = got + self.host = host + self.port = port + + def __repr__(self): + return '<{} expected={} got={} host={} port={}>'.format( + self.__class__.__name__, self.expected, self.got, + self.host, self.port) diff --git a/docs/client.rst b/docs/client.rst index 8304cf0de01..3889e0c50af 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -396,20 +396,29 @@ By default it uses strict checks for HTTPS protocol. Certification checks can be relaxed by passing ``verify_ssl=False``:: >>> conn = aiohttp.TCPConnector(verify_ssl=False) - >>> r = yield from aiohttp.request( - ... 'get', 'https://example.com', connector=conn) + >>> session = aiohttp.ClientSession(connector=conn) + >>> r = yield from session.get('https://example.com') If you need to setup custom ssl parameters (use own certification files for example) you can create a :class:`ssl.SSLContext` instance and pass it into the connector:: - >>> sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - >>> sslcontext.verify_mode = ssl.CERT_REQUIRED - >>> sslcontext.load_verify_locations("/etc/ssl/certs/ca-bundle.crt") + >>> sslcontext = ssl.create_default_context(cafile='/path/to/ca-bundle.crt') >>> conn = aiohttp.TCPConnector(ssl_context=sslcontext) - >>> r = yield from aiohttp.request( - ... 'get', 'https://example.com', connector=conn) + >>> session = aiohttp.ClientSession(connector=conn) + >>> r = yield from session.get('https://example.com') + +You may also verify certificates via fingerprint:: + + >>> # hex string of md5, sha1, or sha256 of expected cert (in DER format) + >>> expected_md5 = 'ca3b499c75768e7313384e243f15cacb' + >>> conn = aiohttp.TCPConnector(expect_fingerprint=expected_md5) + >>> session = aiohttp.ClientSession(connector=conn) + >>> r = yield from session.get('https://www.python.org') + Traceback (most recent call last)\: + File "", line 1, in + FingerprintMismatch: ('ca3b499c75768e7313384e243f15cacb', 'a20647adaaf5d85c4a995e62793b063d', 'www.python.org', 443) Unix domain sockets diff --git a/tests/sample.crt.der b/tests/sample.crt.der new file mode 100644 index 00000000000..ce22b75b9e0 Binary files /dev/null and b/tests/sample.crt.der differ diff --git a/tests/test_connector.py b/tests/test_connector.py index fc81d559045..68c8dc78176 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -12,6 +12,7 @@ import aiohttp from aiohttp import client from aiohttp import test_utils +from aiohttp.errors import FingerprintMismatch from aiohttp.client import ClientResponse, ClientRequest from aiohttp.connector import Connection @@ -452,10 +453,51 @@ def test_cleanup3(self): def test_tcp_connector_ctor(self): conn = aiohttp.TCPConnector(loop=self.loop) self.assertTrue(conn.verify_ssl) + self.assertIs(conn.expect_fingerprint, None) self.assertFalse(conn.resolve) self.assertEqual(conn.family, socket.AF_INET) self.assertEqual(conn.resolved_hosts, {}) + def test_tcp_connector_ctor_expect_fingerprint_valid(self): + valid = '7393fd3aed081d6fa9ae71391ae3c57f89e76cf9' + conn = aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=valid) + self.assertEqual(conn.expect_fingerprint, valid) + + def test_tcp_connector_expect_fingerprint_invalid_len(self): + invalid = 'a1b2c3' + with self.assertRaises(ValueError): + aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=invalid) + + def test_tcp_connector_expect_fingerprint(self): + # the even-index fingerprints below are for sample.crt.der, + # the certificate presented by test_utils.run_server + testcases = ( + # md5 + 'a20647adaaf5d85c4a995e62793b063d', # good + 'ffffffffffffffffffffffffffffffff', # bad + # sha1 + '7393fd3aed081d6fa9ae71391ae3c57f89e76cf9', # good + 'ffffffffffffffffffffffffffffffffffffffff', # bad + # sha256 + '309ac94483dc9127889111a16497fdcb7e37551444404c11ab99a8aeb714ee8b', # good # flake8: noqa + 'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', # bad # flake8: noqa + ) + for i, fingerprint in enumerate(testcases): + expect_fail = i % 2 + conn = aiohttp.TCPConnector(loop=self.loop, verify_ssl=False, + expect_fingerprint=fingerprint) + with test_utils.run_server(self.loop, use_ssl=True) as httpd: + coro = client.request('get', httpd.url('method', 'get'), + connector=conn, loop=self.loop) + if expect_fail: + with self.assertRaises(FingerprintMismatch) as cm: + self.loop.run_until_complete(coro) + self.assertEqual(cm.exception.expected, fingerprint) + self.assertEqual(cm.exception.got, testcases[i-1]) + else: + # should not raise + self.loop.run_until_complete(coro) + def test_tcp_connector_clear_resolved_hosts(self): conn = aiohttp.TCPConnector(loop=self.loop) info = object()