Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix wrap ssl errors for proxy connector. #2446

Merged
merged 1 commit into from
Oct 30, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/2408.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix ClientConnectorSSLError and ClientProxyConnectionError for proxy connector
99 changes: 52 additions & 47 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,22 @@ def _get_fingerprint_and_hashfunc(self, req):
return (None, None)

@asyncio.coroutine
def _create_direct_connection(self, req):
def _wrap_create_connection(self, *args,
req, client_error=ClientConnectorError,
**kwargs):
try:
return (yield from self._loop.create_connection(*args, **kwargs))
except certificate_errors as exc:
raise ClientConnectorCertificateError(
req.connection_key, exc) from exc
except ssl_errors as exc:
raise ClientConnectorSSLError(req.connection_key, exc) from exc
except OSError as exc:
raise client_error(req.connection_key, exc) from exc

@asyncio.coroutine
def _create_direct_connection(self, req,
*, client_error=ClientConnectorError):
sslcontext = self._get_ssl_context(req)
fingerprint, hashfunc = self._get_fingerprint_and_hashfunc(req)

Expand All @@ -792,45 +807,36 @@ def _create_direct_connection(self, req):
# it is problem of resolving proxy ip itself
raise ClientConnectorError(req.connection_key, exc) from exc

hosts = yield from self._resolve_host(req.url.raw_host, req.port)

for hinfo in hosts:
try:
host = hinfo['host']
port = hinfo['port']
transp, proto = yield from self._loop.create_connection(
self._factory, host, port,
ssl=sslcontext, family=hinfo['family'],
proto=hinfo['proto'], flags=hinfo['flags'],
server_hostname=hinfo['hostname'] if sslcontext else None,
local_addr=self._local_addr)
has_cert = transp.get_extra_info('sslcontext')
if has_cert and fingerprint:
sock = transp.get_extra_info('socket')
if not hasattr(sock, 'getpeercert'):
# Workaround for asyncio 3.5.0
# Starting from 3.5.1 version
# there is 'ssl_object' extra info in transport
sock = transp._ssl_protocol._sslpipe.ssl_object
# gives DER-encoded cert as a sequence of bytes (or None)
cert = sock.getpeercert(binary_form=True)
assert cert
got = hashfunc(cert).digest()
expected = fingerprint
if got != expected:
transp.close()
if not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transp)
raise ServerFingerprintMismatch(
expected, got, host, port)
return transp, proto
except certificate_errors as exc:
raise ClientConnectorCertificateError(
req.connection_key, exc) from exc
except ssl_errors as exc:
raise ClientConnectorSSLError(req.connection_key, exc) from exc
except OSError as exc:
raise ClientConnectorError(req.connection_key, exc) from exc
host = hinfo['host']
port = hinfo['port']
transp, proto = yield from self._wrap_create_connection(
self._factory, host, port,
ssl=sslcontext, family=hinfo['family'],
proto=hinfo['proto'], flags=hinfo['flags'],
server_hostname=hinfo['hostname'] if sslcontext else None,
local_addr=self._local_addr,
req=req, client_error=client_error)
has_cert = transp.get_extra_info('sslcontext')
if has_cert and fingerprint:
sock = transp.get_extra_info('socket')
if not hasattr(sock, 'getpeercert'):
# Workaround for asyncio 3.5.0
# Starting from 3.5.1 version
# there is 'ssl_object' extra info in transport
sock = transp._ssl_protocol._sslpipe.ssl_object
# gives DER-encoded cert as a sequence of bytes (or None)
cert = sock.getpeercert(binary_form=True)
assert cert
got = hashfunc(cert).digest()
expected = fingerprint
if got != expected:
transp.close()
if not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transp)
raise ServerFingerprintMismatch(
expected, got, host, port)
return transp, proto

@asyncio.coroutine
def _create_proxy_connection(self, req):
Expand All @@ -847,12 +853,10 @@ def _create_proxy_connection(self, req):
verify_ssl=req.verify_ssl,
fingerprint=req.fingerprint,
ssl_context=req.ssl_context)
try:
# create connection to proxy server
transport, proto = yield from self._create_direct_connection(
proxy_req)
except OSError as exc:
raise ClientProxyConnectionError(proxy_req, exc) from exc

# create connection to proxy server
transport, proto = yield from self._create_direct_connection(
proxy_req, client_error=ClientProxyConnectionError)

auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
if auth is not None:
Expand Down Expand Up @@ -903,9 +907,10 @@ def _create_proxy_connection(self, req):
finally:
transport.close()

transport, proto = yield from self._loop.create_connection(
transport, proto = yield from self._wrap_create_connection(
self._factory, ssl=sslcontext, sock=rawsock,
server_hostname=req.host)
server_hostname=req.host,
req=req)
finally:
proxy_resp.close()

Expand Down
129 changes: 105 additions & 24 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import gc
import hashlib
import socket
import ssl
import unittest
from unittest import mock

Expand Down Expand Up @@ -231,6 +232,24 @@ def test_proxy_dns_error(self):
self.assertEqual(req.url.path, '/')
self.assertEqual(dict(req.headers), expected_headers)

def test_proxy_connection_error(self):
connector = aiohttp.TCPConnector(loop=self.loop)
connector._resolve_host = make_mocked_coro([{
'hostname': 'www.python.org',
'host': '127.0.0.1', 'port': 80,
'family': socket.AF_INET, 'proto': 0,
'flags': socket.AI_NUMERICHOST}])
connector._loop.create_connection = make_mocked_coro(
raise_exception=OSError('dont take it serious'))

req = ClientRequest(
'GET', URL('http://www.python.org'),
proxy=URL('http://proxy.example.com'),
loop=self.loop,
)
with self.assertRaises(aiohttp.ClientProxyConnectionError):
self.loop.run_until_complete(connector.connect(req))

@mock.patch('aiohttp.connector.ClientRequest')
def test_auth(self, ClientRequestMock):
proxy_req = ClientRequest(
Expand Down Expand Up @@ -314,30 +333,6 @@ def test_auth_from_url(self, ClientRequestMock):
ssl_context=None, verify_ssl=None)
conn.close()

@mock.patch('aiohttp.connector.ClientRequest')
def test_auth__not_modifying_request(self, ClientRequestMock):
proxy_req = ClientRequest('GET',
URL('http://user:[email protected]'),
loop=self.loop)
ClientRequestMock.return_value = proxy_req
proxy_req_headers = dict(proxy_req.headers)

connector = aiohttp.TCPConnector(loop=self.loop)
connector._resolve_host = make_mocked_coro(
raise_exception=OSError('nothing personal'))

req = ClientRequest(
'GET', URL('http://www.python.org'),
proxy=URL('http://user:[email protected]'),
loop=self.loop,
)
req_headers = dict(req.headers)
with self.assertRaises(aiohttp.ClientConnectorError):
self.loop.run_until_complete(connector.connect(req))
self.assertEqual(req.headers, req_headers)
self.assertEqual(req.url.path, '/')
self.assertEqual(proxy_req.headers, proxy_req_headers)

@mock.patch('aiohttp.connector.ClientRequest')
def test_https_connect(self, ClientRequestMock):
proxy_req = ClientRequest('GET', URL('http://proxy.example.com'),
Expand Down Expand Up @@ -375,6 +370,92 @@ def test_https_connect(self, ClientRequestMock):
proxy_resp.close()
self.loop.run_until_complete(req.close())

@mock.patch('aiohttp.connector.ClientRequest')
def test_https_connect_certificate_error(self, ClientRequestMock):
proxy_req = ClientRequest('GET', URL('http://proxy.example.com'),
loop=self.loop)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse('get', URL('http://proxy.example.com'))
proxy_resp._loop = self.loop
proxy_req.send = send_mock = mock.Mock()
send_mock.return_value = proxy_resp
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

connector = aiohttp.TCPConnector(loop=self.loop)
connector._resolve_host = make_mocked_coro(
[{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80,
'family': socket.AF_INET, 'proto': 0, 'flags': 0}])

seq = 0

@asyncio.coroutine
def create_connection(*args, **kwargs):
nonlocal seq
seq += 1

# connection to http://proxy.example.com
if seq == 1:
return mock.Mock(), mock.Mock()
# connection to https://www.python.org
elif seq == 2:
raise ssl.CertificateError
else:
assert False

self.loop.create_connection = create_connection

req = ClientRequest(
'GET', URL('https://www.python.org'),
proxy=URL('http://proxy.example.com'),
loop=self.loop,
)
with self.assertRaises(aiohttp.ClientConnectorCertificateError):
self.loop.run_until_complete(connector._create_connection(req))

@mock.patch('aiohttp.connector.ClientRequest')
def test_https_connect_ssl_error(self, ClientRequestMock):
proxy_req = ClientRequest('GET', URL('http://proxy.example.com'),
loop=self.loop)
ClientRequestMock.return_value = proxy_req

proxy_resp = ClientResponse('get', URL('http://proxy.example.com'))
proxy_resp._loop = self.loop
proxy_req.send = send_mock = mock.Mock()
send_mock.return_value = proxy_resp
proxy_resp.start = make_mocked_coro(mock.Mock(status=200))

connector = aiohttp.TCPConnector(loop=self.loop)
connector._resolve_host = make_mocked_coro(
[{'hostname': 'hostname', 'host': '127.0.0.1', 'port': 80,
'family': socket.AF_INET, 'proto': 0, 'flags': 0}])

seq = 0

@asyncio.coroutine
def create_connection(*args, **kwargs):
nonlocal seq
seq += 1

# connection to http://proxy.example.com
if seq == 1:
return mock.Mock(), mock.Mock()
# connection to https://www.python.org
elif seq == 2:
raise ssl.SSLError
else:
assert False

self.loop.create_connection = create_connection

req = ClientRequest(
'GET', URL('https://www.python.org'),
proxy=URL('http://proxy.example.com'),
loop=self.loop,
)
with self.assertRaises(aiohttp.ClientConnectorSSLError):
self.loop.run_until_complete(connector._create_connection(req))

@mock.patch('aiohttp.connector.ClientRequest')
def test_https_connect_runtime_error(self, ClientRequestMock):
proxy_req = ClientRequest('GET', URL('http://proxy.example.com'),
Expand Down