From 68cb55884f9b24c9cb302607d7943e76c1de8478 Mon Sep 17 00:00:00 2001 From: hellysmile Date: Mon, 30 Oct 2017 15:42:07 +0200 Subject: [PATCH] Fix wrap ssl errors for proxy connector. --- CHANGES/2408.bugfix | 1 + aiohttp/connector.py | 99 +++++++++++++++++---------------- tests/test_proxy.py | 129 +++++++++++++++++++++++++++++++++++-------- 3 files changed, 158 insertions(+), 71 deletions(-) create mode 100644 CHANGES/2408.bugfix diff --git a/CHANGES/2408.bugfix b/CHANGES/2408.bugfix new file mode 100644 index 00000000000..4fec0c8b07c --- /dev/null +++ b/CHANGES/2408.bugfix @@ -0,0 +1 @@ +Fix ClientConnectorSSLError and ClientProxyConnectionError for proxy connector diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 47f8f3fb95b..fbad64c4435 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -781,7 +781,22 @@ def _get_fingerprint_and_hashfunc(self, req): return (None, None) @asyncio.coroutine - def _create_direct_connection(self, req): + def _wrap_create_connection(self, *args, + req, client_error=ClientConnectorError, + **kwargs): + try: + return (yield from self._loop.create_connection(*args, **kwargs)) + except certificate_errors as exc: + raise ClientConnectorCertificateError( + req.connection_key, exc) from exc + except ssl_errors as exc: + raise ClientConnectorSSLError(req.connection_key, exc) from exc + except OSError as exc: + raise client_error(req.connection_key, exc) from exc + + @asyncio.coroutine + def _create_direct_connection(self, req, + *, client_error=ClientConnectorError): sslcontext = self._get_ssl_context(req) fingerprint, hashfunc = self._get_fingerprint_and_hashfunc(req) @@ -792,45 +807,36 @@ def _create_direct_connection(self, req): # it is problem of resolving proxy ip itself raise ClientConnectorError(req.connection_key, exc) from exc - hosts = yield from self._resolve_host(req.url.raw_host, req.port) - for hinfo in hosts: - try: - host = hinfo['host'] - port = hinfo['port'] - transp, proto = yield from self._loop.create_connection( - self._factory, host, port, - ssl=sslcontext, family=hinfo['family'], - proto=hinfo['proto'], flags=hinfo['flags'], - server_hostname=hinfo['hostname'] if sslcontext else None, - local_addr=self._local_addr) - has_cert = transp.get_extra_info('sslcontext') - if has_cert and fingerprint: - sock = transp.get_extra_info('socket') - if not hasattr(sock, 'getpeercert'): - # Workaround for asyncio 3.5.0 - # Starting from 3.5.1 version - # there is 'ssl_object' extra info in transport - sock = transp._ssl_protocol._sslpipe.ssl_object - # gives DER-encoded cert as a sequence of bytes (or None) - cert = sock.getpeercert(binary_form=True) - assert cert - got = hashfunc(cert).digest() - expected = fingerprint - if got != expected: - transp.close() - if not self._cleanup_closed_disabled: - self._cleanup_closed_transports.append(transp) - raise ServerFingerprintMismatch( - expected, got, host, port) - return transp, proto - except certificate_errors as exc: - raise ClientConnectorCertificateError( - req.connection_key, exc) from exc - except ssl_errors as exc: - raise ClientConnectorSSLError(req.connection_key, exc) from exc - except OSError as exc: - raise ClientConnectorError(req.connection_key, exc) from exc + host = hinfo['host'] + port = hinfo['port'] + transp, proto = yield from self._wrap_create_connection( + self._factory, host, port, + ssl=sslcontext, family=hinfo['family'], + proto=hinfo['proto'], flags=hinfo['flags'], + server_hostname=hinfo['hostname'] if sslcontext else None, + local_addr=self._local_addr, + req=req, client_error=client_error) + has_cert = transp.get_extra_info('sslcontext') + if has_cert and fingerprint: + sock = transp.get_extra_info('socket') + if not hasattr(sock, 'getpeercert'): + # Workaround for asyncio 3.5.0 + # Starting from 3.5.1 version + # there is 'ssl_object' extra info in transport + sock = transp._ssl_protocol._sslpipe.ssl_object + # gives DER-encoded cert as a sequence of bytes (or None) + cert = sock.getpeercert(binary_form=True) + assert cert + got = hashfunc(cert).digest() + expected = fingerprint + if got != expected: + transp.close() + if not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transp) + raise ServerFingerprintMismatch( + expected, got, host, port) + return transp, proto @asyncio.coroutine def _create_proxy_connection(self, req): @@ -847,12 +853,10 @@ def _create_proxy_connection(self, req): verify_ssl=req.verify_ssl, fingerprint=req.fingerprint, ssl_context=req.ssl_context) - try: - # create connection to proxy server - transport, proto = yield from self._create_direct_connection( - proxy_req) - except OSError as exc: - raise ClientProxyConnectionError(proxy_req, exc) from exc + + # create connection to proxy server + transport, proto = yield from self._create_direct_connection( + proxy_req, client_error=ClientProxyConnectionError) auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) if auth is not None: @@ -903,9 +907,10 @@ def _create_proxy_connection(self, req): finally: transport.close() - transport, proto = yield from self._loop.create_connection( + transport, proto = yield from self._wrap_create_connection( self._factory, ssl=sslcontext, sock=rawsock, - server_hostname=req.host) + server_hostname=req.host, + req=req) finally: proxy_resp.close() diff --git a/tests/test_proxy.py b/tests/test_proxy.py index ab123c05a56..cc066904e2b 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -2,6 +2,7 @@ import gc import hashlib import socket +import ssl import unittest from unittest import mock @@ -231,6 +232,24 @@ def test_proxy_dns_error(self): self.assertEqual(req.url.path, '/') self.assertEqual(dict(req.headers), expected_headers) + def test_proxy_connection_error(self): + connector = aiohttp.TCPConnector(loop=self.loop) + connector._resolve_host = make_mocked_coro([{ + 'hostname': 'www.python.org', + 'host': '127.0.0.1', 'port': 80, + 'family': socket.AF_INET, 'proto': 0, + 'flags': socket.AI_NUMERICHOST}]) + connector._loop.create_connection = make_mocked_coro( + raise_exception=OSError('dont take it serious')) + + req = ClientRequest( + 'GET', URL('http://www.python.org'), + proxy=URL('http://proxy.example.com'), + loop=self.loop, + ) + with self.assertRaises(aiohttp.ClientProxyConnectionError): + self.loop.run_until_complete(connector.connect(req)) + @mock.patch('aiohttp.connector.ClientRequest') def test_auth(self, ClientRequestMock): proxy_req = ClientRequest( @@ -314,30 +333,6 @@ def test_auth_from_url(self, ClientRequestMock): ssl_context=None, verify_ssl=None) conn.close() - @mock.patch('aiohttp.connector.ClientRequest') - def test_auth__not_modifying_request(self, ClientRequestMock): - proxy_req = ClientRequest('GET', - URL('http://user:pass@proxy.example.com'), - loop=self.loop) - ClientRequestMock.return_value = proxy_req - proxy_req_headers = dict(proxy_req.headers) - - connector = aiohttp.TCPConnector(loop=self.loop) - connector._resolve_host = make_mocked_coro( - raise_exception=OSError('nothing personal')) - - req = ClientRequest( - 'GET', URL('http://www.python.org'), - proxy=URL('http://user:pass@proxy.example.com'), - loop=self.loop, - ) - req_headers = dict(req.headers) - with self.assertRaises(aiohttp.ClientConnectorError): - self.loop.run_until_complete(connector.connect(req)) - self.assertEqual(req.headers, req_headers) - self.assertEqual(req.url.path, '/') - self.assertEqual(proxy_req.headers, proxy_req_headers) - @mock.patch('aiohttp.connector.ClientRequest') def test_https_connect(self, ClientRequestMock): proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), @@ -375,6 +370,92 @@ def test_https_connect(self, ClientRequestMock): proxy_resp.close() self.loop.run_until_complete(req.close()) + @mock.patch('aiohttp.connector.ClientRequest') + def test_https_connect_certificate_error(self, ClientRequestMock): + proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), + loop=self.loop) + ClientRequestMock.return_value = proxy_req + + proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) + proxy_resp._loop = self.loop + proxy_req.send = send_mock = mock.Mock() + send_mock.return_value = proxy_resp + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) + + connector = aiohttp.TCPConnector(loop=self.loop) + connector._resolve_host = make_mocked_coro( + [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, + 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + + seq = 0 + + @asyncio.coroutine + def create_connection(*args, **kwargs): + nonlocal seq + seq += 1 + + # connection to http://proxy.example.com + if seq == 1: + return mock.Mock(), mock.Mock() + # connection to https://www.python.org + elif seq == 2: + raise ssl.CertificateError + else: + assert False + + self.loop.create_connection = create_connection + + req = ClientRequest( + 'GET', URL('https://www.python.org'), + proxy=URL('http://proxy.example.com'), + loop=self.loop, + ) + with self.assertRaises(aiohttp.ClientConnectorCertificateError): + self.loop.run_until_complete(connector._create_connection(req)) + + @mock.patch('aiohttp.connector.ClientRequest') + def test_https_connect_ssl_error(self, ClientRequestMock): + proxy_req = ClientRequest('GET', URL('http://proxy.example.com'), + loop=self.loop) + ClientRequestMock.return_value = proxy_req + + proxy_resp = ClientResponse('get', URL('http://proxy.example.com')) + proxy_resp._loop = self.loop + proxy_req.send = send_mock = mock.Mock() + send_mock.return_value = proxy_resp + proxy_resp.start = make_mocked_coro(mock.Mock(status=200)) + + connector = aiohttp.TCPConnector(loop=self.loop) + connector._resolve_host = make_mocked_coro( + [{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80, + 'family': socket.AF_INET, 'proto': 0, 'flags': 0}]) + + seq = 0 + + @asyncio.coroutine + def create_connection(*args, **kwargs): + nonlocal seq + seq += 1 + + # connection to http://proxy.example.com + if seq == 1: + return mock.Mock(), mock.Mock() + # connection to https://www.python.org + elif seq == 2: + raise ssl.SSLError + else: + assert False + + self.loop.create_connection = create_connection + + req = ClientRequest( + 'GET', URL('https://www.python.org'), + proxy=URL('http://proxy.example.com'), + loop=self.loop, + ) + with self.assertRaises(aiohttp.ClientConnectorSSLError): + self.loop.run_until_complete(connector._create_connection(req)) + @mock.patch('aiohttp.connector.ClientRequest') def test_https_connect_runtime_error(self, ClientRequestMock): proxy_req = ClientRequest('GET', URL('http://proxy.example.com'),