Skip to content

Commit

Permalink
support new TCPConnector param expect_fingerprint
Browse files Browse the repository at this point in the history
  • Loading branch information
requiredfield committed May 16, 2015
1 parent fc7cbbf commit 17008ed
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ CHANGES
0.16.0 (XX-XX-XXXX)
-------------------

- Support new `expect_fingerprint` param of TCPConnector to enable verifying
ssl certificates via md5, sha1, or sha256 fingerprint

- Setup uploaded filename if field value is binary and transfer
encoding is not specified #349

Expand Down
50 changes: 46 additions & 4 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import traceback
import warnings

from binascii import hexlify, unhexlify
from collections import defaultdict
from hashlib import md5, sha1, sha256
from itertools import chain
from math import ceil

Expand All @@ -17,6 +19,7 @@
from .errors import ServerDisconnectedError
from .errors import HttpProxyError, ProxyConnectionError
from .errors import ClientOSError, ClientTimeoutError
from .errors import FingerprintMismatch
from .helpers import BasicAuth


Expand All @@ -25,6 +28,12 @@
PY_34 = sys.version_info >= (3, 4)
PY_343 = sys.version_info >= (3, 4, 3)

HASHFUNC_BY_DIGESTLEN = {
32: md5,
40: sha1,
64: sha256,
}


class Connection(object):

Expand Down Expand Up @@ -347,13 +356,16 @@ class TCPConnector(BaseConnector):
"""TCP connector.
:param bool verify_ssl: Set to True to check ssl certifications.
:param str expect_fingerprint: Set to the md5, sha1, or sha256 fingerprint
(as a hexadecimal string) of the expected certificate (DER-encoded)
to verify the cert matches. May be interspersed with colons.
:param bool resolve: Set to True to do DNS lookup for host name.
:param family: socket address family
:param args: see :class:`BaseConnector`
:param kwargs: see :class:`BaseConnector`
"""

def __init__(self, *, verify_ssl=True,
def __init__(self, *, verify_ssl=True, expect_fingerprint=None,
resolve=False, family=socket.AF_INET, ssl_context=None,
**kwargs):
super().__init__(**kwargs)
Expand All @@ -364,6 +376,17 @@ def __init__(self, *, verify_ssl=True,
"verify_ssl=False or specify ssl_context, not both.")

self._verify_ssl = verify_ssl

if expect_fingerprint:
expect_fingerprint = expect_fingerprint.replace(':', '').lower()
digestlen = len(expect_fingerprint)
hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen)
if not hashfunc:
raise ValueError('Fingerprint is of invalid length.')
self._hashfunc = hashfunc
self._fingerprint_bytes = unhexlify(expect_fingerprint)

self._expect_fingerprint = expect_fingerprint
self._ssl_context = ssl_context
self._family = family
self._resolve = resolve
Expand All @@ -374,6 +397,11 @@ def verify_ssl(self):
"""Do check for ssl certifications?"""
return self._verify_ssl

@property
def expect_fingerprint(self):
"""Expected value of ssl certificate fingerprint, if any."""
return self._expect_fingerprint

@property
def ssl_context(self):
"""SSLContext instance for https requests.
Expand Down Expand Up @@ -464,11 +492,25 @@ def _create_connection(self, req):

for hinfo in hosts:
try:
return (yield from self._loop.create_connection(
self._factory, hinfo['host'], hinfo['port'],
host = hinfo['host']
port = hinfo['port']
conn = yield from self._loop.create_connection(
self._factory, host, port,
ssl=sslcontext, family=hinfo['family'],
proto=hinfo['proto'], flags=hinfo['flags'],
server_hostname=hinfo['hostname'] if sslcontext else None))
server_hostname=hinfo['hostname'] if sslcontext else None)
if req.ssl and self._expect_fingerprint:
transport = conn[0]
sock = transport.get_extra_info('socket')
# gives DER-encoded cert as a sequence of bytes (or None)
cert = sock.getpeercert(binary_form=True)
got = cert and self._hashfunc(cert).digest()
expected = self._fingerprint_bytes
if expected != got:
got = got and hexlify(got).decode('ascii')
expected = hexlify(expected).decode('ascii')
raise FingerprintMismatch(expected, got, host, port)
return conn
except OSError as e:
exc = e
else:
Expand Down
16 changes: 16 additions & 0 deletions aiohttp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
'ClientError', 'ClientHttpProcessingError', 'ClientConnectionError',
'ClientOSError', 'ClientTimeoutError', 'ProxyConnectionError',
'ClientRequestError', 'ClientResponseError',
'FingerprintMismatch',

'WSServerHandshakeError', 'WSClientDisconnectedError')

Expand Down Expand Up @@ -170,3 +171,18 @@ class LineLimitExceededParserError(ParserError):
def __init__(self, msg, limit):
super().__init__(msg)
self.limit = limit


class FingerprintMismatch(ClientConnectionError):
"""SSL certificate does not match expected fingerprint."""

def __init__(self, expected, got, host, port):
self.expected = expected
self.got = got
self.host = host
self.port = port

def __repr__(self):
return '<{} expected={} got={} host={} port={}>'.format(
self.__class__.__name__, self.expected, self.got,
self.host, self.port)
23 changes: 16 additions & 7 deletions docs/client.rst
Original file line number Diff line number Diff line change
Expand Up @@ -396,20 +396,29 @@ By default it uses strict checks for HTTPS protocol. Certification
checks can be relaxed by passing ``verify_ssl=False``::

>>> conn = aiohttp.TCPConnector(verify_ssl=False)
>>> r = yield from aiohttp.request(
... 'get', 'https://example.com', connector=conn)
>>> session = aiohttp.ClientSession(connector=conn)
>>> r = yield from session.get('https://example.com')


If you need to setup custom ssl parameters (use own certification
files for example) you can create a :class:`ssl.SSLContext` instance and
pass it into the connector::

>>> sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
>>> sslcontext.verify_mode = ssl.CERT_REQUIRED
>>> sslcontext.load_verify_locations("/etc/ssl/certs/ca-bundle.crt")
>>> sslcontext = ssl.create_default_context(cafile='/path/to/ca-bundle.crt')
>>> conn = aiohttp.TCPConnector(ssl_context=sslcontext)
>>> r = yield from aiohttp.request(
... 'get', 'https://example.com', connector=conn)
>>> session = aiohttp.ClientSession(connector=conn)
>>> r = yield from session.get('https://example.com')

You may also verify certificates via fingerprint::

>>> # hex string of md5, sha1, or sha256 of expected cert (in DER format)
>>> expected_md5 = 'ca3b499c75768e7313384e243f15cacb'
>>> conn = aiohttp.TCPConnector(expect_fingerprint=expected_md5)
>>> session = aiohttp.ClientSession(connector=conn)
>>> r = yield from session.get('https://www.python.org')
Traceback (most recent call last)\:
File "<stdin>", line 1, in <module>
FingerprintMismatch: ('ca3b499c75768e7313384e243f15cacb', 'a20647adaaf5d85c4a995e62793b063d', 'www.python.org', 443)


Unix domain sockets
Expand Down
Binary file added tests/sample.crt.der
Binary file not shown.
31 changes: 31 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import aiohttp
from aiohttp import client
from aiohttp import test_utils
from aiohttp.errors import FingerprintMismatch
from aiohttp.client import ClientResponse, ClientRequest
from aiohttp.connector import Connection

Expand Down Expand Up @@ -452,10 +453,40 @@ def test_cleanup3(self):
def test_tcp_connector_ctor(self):
conn = aiohttp.TCPConnector(loop=self.loop)
self.assertTrue(conn.verify_ssl)
self.assertIs(conn.expect_fingerprint, None)
self.assertFalse(conn.resolve)
self.assertEqual(conn.family, socket.AF_INET)
self.assertEqual(conn.resolved_hosts, {})

def test_tcp_connector_ctor_expect_fingerprint_valid(self):
valid = '7393fd3aed081d6fa9ae71391ae3c57f89e76cf9'
conn = aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=valid)
self.assertEqual(conn.expect_fingerprint, valid)

def test_tcp_connector_expect_fingerprint_invalid(self):
invalid = 'a1b2c3'
with self.assertRaises(ValueError):
aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=invalid)

def test_tcp_connector_expect_fingerprint(self):
# sha1 fingerprint of ./sample.crt.der
fpgood = '7393fd3aed081d6fa9ae71391ae3c57f89e76cf9'
fpbad = 'badbadbadbadbadbadbadbadbadbadbadbadbad1'
for fp in (fpgood, fpbad):
conn = aiohttp.TCPConnector(loop=self.loop, verify_ssl=False,
expect_fingerprint=fp)
with test_utils.run_server(self.loop, use_ssl=True) as httpd:
coro = client.request('get', httpd.url('method', 'get'),
connector=conn, loop=self.loop)
if fp == fpgood:
# should not raise
self.loop.run_until_complete(coro)
else:
with self.assertRaises(FingerprintMismatch) as cm:
self.loop.run_until_complete(coro)
self.assertEqual(cm.exception.expected, fpbad)
self.assertEqual(cm.exception.got, fpgood)

def test_tcp_connector_clear_resolved_hosts(self):
conn = aiohttp.TCPConnector(loop=self.loop)
info = object()
Expand Down

0 comments on commit 17008ed

Please sign in to comment.