diff --git a/CHANGES/2981.bugfix b/CHANGES/2981.bugfix new file mode 100644 index 00000000000..89dff57bf02 --- /dev/null +++ b/CHANGES/2981.bugfix @@ -0,0 +1 @@ +Don't reuse a connection with the same URL but different proxy/TLS settings \ No newline at end of file diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 34b72daf76f..9eacc4e318d 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -241,7 +241,7 @@ def port(self): @property def ssl(self): - return self._conn_key.ssl + return self._conn_key.is_ssl def __str__(self): return ('Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} ' diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index f4688cc300e..9c0a16eba18 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -6,7 +6,6 @@ import sys import traceback import warnings -from collections import namedtuple from hashlib import md5, sha1, sha256 from http.cookies import CookieError, Morsel, SimpleCookie from types import MappingProxyType @@ -136,7 +135,17 @@ def _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint): return ssl -ConnectionKey = namedtuple('ConnectionKey', ['host', 'port', 'ssl']) +@attr.s(slots=True, frozen=True) +class ConnectionKey: + # the key should contain an information about used proxy / TLS + # to prevent reusing wrong connections from a pool + host = attr.ib(type=str) + port = attr.ib(type=int) + is_ssl = attr.ib(type=bool) + ssl = attr.ib() # SSLContext or None + proxy = attr.ib() # URL or None + proxy_auth = attr.ib() # BasicAuth + proxy_headers_hash = attr.ib(type=int) # hash(CIMultiDict) def _is_expected_content_type(response_content_type, expected_content_type): @@ -237,7 +246,14 @@ def ssl(self): @property def connection_key(self): - return ConnectionKey(self.host, self.port, self.is_ssl()) + proxy_headers = self.proxy_headers + if proxy_headers: + h = hash(tuple((k, v) for k, v in proxy_headers.items())) + else: + h = None + return ConnectionKey(self.host, self.port, self.is_ssl(), + self.ssl, + self.proxy, self.proxy_auth, h) @property def host(self): diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 634c01a1caa..6a17a43e13a 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -10,6 +10,8 @@ from itertools import cycle, islice from time import monotonic +import attr + from . import hdrs, helpers from .client_exceptions import (ClientConnectionError, ClientConnectorCertificateError, @@ -272,7 +274,8 @@ def _cleanup(self): if proto.is_connected(): if use_time - deadline < 0: transport = proto.close() - if key[-1] and not self._cleanup_closed_disabled: + if (key.is_ssl and + not self._cleanup_closed_disabled): self._cleanup_closed_transports.append( transport) else: @@ -482,7 +485,7 @@ def _get(self, key): if t1 - t0 > self._keepalive_timeout: transport = proto.close() # only for SSL transports - if key[-1] and not self._cleanup_closed_disabled: + if key.is_ssl and not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transport) else: if not conns: @@ -546,7 +549,7 @@ def _release(self, key, protocol, *, should_close=False): if should_close or protocol.should_close: transport = protocol.close() - if key[-1] and not self._cleanup_closed_disabled: + if key.is_ssl and not self._cleanup_closed_disabled: self._cleanup_closed_transports.append(transport) else: conns = self._conns.get(key) @@ -918,7 +921,10 @@ async def _create_proxy_connection(self, req, traces=None): # asyncio handles this perfectly proxy_req.method = hdrs.METH_CONNECT proxy_req.url = req.url - key = (req.host, req.port, req.ssl) + key = attr.evolve(req.connection_key, + proxy=None, + proxy_auth=None, + proxy_headers_hash=None) conn = Connection(self, key, proto, self._loop) proxy_resp = await proxy_req.send(conn) try: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index bb90d8c724c..66accdc9120 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -255,6 +255,7 @@ syscalls Systemd tarball TCP +TLS teardown Teardown TestClient diff --git a/tests/test_connector.py b/tests/test_connector.py index 1ae32c8f24c..f0ef56a7a4e 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -17,6 +17,7 @@ import aiohttp from aiohttp import client, web from aiohttp.client import ClientRequest +from aiohttp.client_reqrep import ConnectionKey from aiohttp.connector import Connection, _DNSCacheTable from aiohttp.test_utils import make_mocked_coro, unused_port from aiohttp.tracing import Trace @@ -25,19 +26,19 @@ @pytest.fixture() def key(): """Connection key""" - return ('localhost1', 80, False) + return ConnectionKey('localhost', 80, False, None, None, None, None) @pytest.fixture def key2(): """Connection key""" - return ('localhost2', 80, False) + return ConnectionKey('localhost', 80, False, None, None, None, None) @pytest.fixture def ssl_key(): """Connection key""" - return ('localhost', 80, True) + return ConnectionKey('localhost', 80, True, None, None, None, None) @pytest.fixture @@ -266,22 +267,24 @@ def test_get(loop): def test_get_expired(loop): conn = aiohttp.BaseConnector(loop=loop) - assert conn._get(('localhost', 80, False)) is None + key = ConnectionKey('localhost', 80, False, None, None, None, None) + assert conn._get(key) is None proto = mock.Mock() - conn._conns[('localhost', 80, False)] = [(proto, loop.time() - 1000)] - assert conn._get(('localhost', 80, False)) is None + conn._conns[key] = [(proto, loop.time() - 1000)] + assert conn._get(key) is None assert not conn._conns conn.close() def test_get_expired_ssl(loop): conn = aiohttp.BaseConnector(loop=loop, enable_cleanup_closed=True) - assert conn._get(('localhost', 80, True)) is None + key = ConnectionKey('localhost', 80, True, None, None, None, None) + assert conn._get(key) is None proto = mock.Mock() - conn._conns[('localhost', 80, True)] = [(proto, loop.time() - 1000)] - assert conn._get(('localhost', 80, True)) is None + conn._conns[key] = [(proto, loop.time() - 1000)] + assert conn._get(key) is None assert not conn._conns assert conn._cleanup_closed_transports == [proto.close.return_value] conn.close() @@ -443,11 +446,10 @@ def test_release_waiter_no_available(loop, key, key2): conn.close() -def test_release_close(loop): +def test_release_close(loop, key): conn = aiohttp.BaseConnector(loop=loop) proto = mock.Mock(should_close=True) - key = ('localhost', 80, False) conn._acquired.add(proto) conn._release(key, proto) assert not conn._conns @@ -946,11 +948,10 @@ def test_get_pop_empty_conns(loop): assert not conn._conns -def test_release_close_do_not_add_to_pool(loop): +def test_release_close_do_not_add_to_pool(loop, key): # see issue #473 conn = aiohttp.BaseConnector(loop=loop) - key = ('127.0.0.1', 80, False) proto = mock.Mock(should_close=True) conn._acquired.add(proto) @@ -958,8 +959,7 @@ def test_release_close_do_not_add_to_pool(loop): assert not conn._conns -def test_release_close_do_not_delete_existing_connections(loop): - key = ('127.0.0.1', 80, False) +def test_release_close_do_not_delete_existing_connections(loop, key): proto1 = mock.Mock() conn = aiohttp.BaseConnector(loop=loop) @@ -987,24 +987,22 @@ def test_release_not_started(loop): conn.close() -def test_release_not_opened(loop): +def test_release_not_opened(loop, key): conn = aiohttp.BaseConnector(loop=loop) proto = mock.Mock() - key = ('localhost', 80, False) conn._acquired.add(proto) conn._release(key, proto) assert proto.close.called -async def test_connect(loop): +async def test_connect(loop, key): proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://host:80'), loop=loop) + req = ClientRequest('GET', URL('http://localhost:80'), loop=loop) conn = aiohttp.BaseConnector(loop=loop) - key = ('host', 80, False) conn._conns[key] = [(proto, loop.time())] conn._create_connection = mock.Mock() conn._create_connection.return_value = loop.create_future() @@ -1098,8 +1096,7 @@ def test_ctor_cleanup(): assert conn._cleanup_closed_handle is not None -def test_cleanup(): - key = ('localhost', 80, False) +def test_cleanup(key): testset = { key: [(mock.Mock(), 10), (mock.Mock(), 300)], @@ -1119,10 +1116,9 @@ def test_cleanup(): assert conn._cleanup_handle is not None -def test_cleanup_close_ssl_transport(): +def test_cleanup_close_ssl_transport(ssl_key): proto = mock.Mock() - key = ('localhost', 80, True) - testset = {key: [(proto, 10)]} + testset = {ssl_key: [(proto, 10)]} loop = mock.Mock() loop.time.return_value = 300 @@ -1153,8 +1149,7 @@ def test_cleanup2(): conn.close() -def test_cleanup3(): - key = ('localhost', 80, False) +def test_cleanup3(key): testset = {key: [(mock.Mock(), 290.1), (mock.Mock(), 305.1)]} testset[key][0][0].is_connected.return_value = True @@ -1374,7 +1369,7 @@ async def test_connect_with_limit(loop, key): proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), + req = ClientRequest('GET', URL('http://localhost:80'), loop=loop, response_class=mock.Mock()) @@ -1499,7 +1494,7 @@ async def test_connect_reuseconn_tracing(loop, key): proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), + req = ClientRequest('GET', URL('http://localhost:80'), loop=loop, response_class=mock.Mock()) @@ -1520,7 +1515,7 @@ async def test_connect_with_limit_and_limit_per_host(loop, key): proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), loop=loop) + req = ClientRequest('GET', URL('http://localhost:80'), loop=loop) conn = aiohttp.BaseConnector(loop=loop, limit=1000, limit_per_host=1) conn._conns[key] = [(proto, loop.time())] @@ -1586,7 +1581,7 @@ async def test_connect_with_no_limits(loop, key): proto = mock.Mock() proto.is_connected.return_value = True - req = ClientRequest('GET', URL('http://localhost1:80'), loop=loop) + req = ClientRequest('GET', URL('http://localhost:80'), loop=loop) conn = aiohttp.BaseConnector(loop=loop, limit=0, limit_per_host=0) conn._conns[key] = [(proto, loop.time())] @@ -1834,11 +1829,11 @@ def test_force_close_and_explicit_keep_alive(loop): assert conn -async def test_error_on_connection(loop): +async def test_error_on_connection(loop, key): conn = aiohttp.BaseConnector(limit=1, loop=loop) req = mock.Mock() - req.connection_key = 'key' + req.connection_key = key proto = mock.Mock() i = 0 @@ -1861,16 +1856,16 @@ async def create_connection(req, traces=None): await asyncio.sleep(0, loop=loop) assert not t1.done() assert not t2.done() - assert len(conn._acquired_per_host['key']) == 1 + assert len(conn._acquired_per_host[key]) == 1 fut.set_result(None) with pytest.raises(OSError): await t1 ret = await t2 - assert len(conn._acquired_per_host['key']) == 1 + assert len(conn._acquired_per_host[key]) == 1 - assert ret._key == 'key' + assert ret._key == key assert ret.protocol == proto assert proto in conn._acquired ret.release() @@ -1898,11 +1893,11 @@ async def create_connection(req, traces=None): await conn2 -async def test_error_on_connection_with_cancelled_waiter(loop): +async def test_error_on_connection_with_cancelled_waiter(loop, key): conn = aiohttp.BaseConnector(limit=1, loop=loop) req = mock.Mock() - req.connection_key = 'key' + req.connection_key = key proto = mock.Mock() i = 0 @@ -1929,7 +1924,7 @@ async def create_connection(req, traces=None): await asyncio.sleep(0, loop=loop) assert not t1.done() assert not t2.done() - assert len(conn._acquired_per_host['key']) == 1 + assert len(conn._acquired_per_host[key]) == 1 fut1.set_result(None) fut2.cancel() @@ -1940,9 +1935,9 @@ async def create_connection(req, traces=None): await t2 ret = await t3 - assert len(conn._acquired_per_host['key']) == 1 + assert len(conn._acquired_per_host[key]) == 1 - assert ret._key == 'key' + assert ret._key == key assert ret.protocol == proto assert proto in conn._acquired ret.release()