Skip to content

Commit

Permalink
Don't reuse a connection with the same URL but different proxy/TLS se…
Browse files Browse the repository at this point in the history
…ttings (#2985)

* Fix #2981: Don't reuse a connection with the same URL but different proxy/TLS settings

* Add TLS to white list

* Fix broken TLS over proxy
  • Loading branch information
asvetlov committed May 10, 2018
1 parent c72910c commit 920b058
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGES/2981.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Don't reuse a connection with the same URL but different proxy/TLS settings
2 changes: 1 addition & 1 deletion aiohttp/client_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def port(self):

@property
def ssl(self):
return self._conn_key.ssl
return self._conn_key.is_ssl

def __str__(self):
return ('Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} '
Expand Down
22 changes: 19 additions & 3 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import sys
import traceback
import warnings
from collections import namedtuple
from hashlib import md5, sha1, sha256
from http.cookies import CookieError, Morsel, SimpleCookie
from types import MappingProxyType
Expand Down Expand Up @@ -136,7 +135,17 @@ def _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint):
return ssl


ConnectionKey = namedtuple('ConnectionKey', ['host', 'port', 'ssl'])
@attr.s(slots=True, frozen=True)
class ConnectionKey:
# the key should contain an information about used proxy / TLS
# to prevent reusing wrong connections from a pool
host = attr.ib(type=str)
port = attr.ib(type=int)
is_ssl = attr.ib(type=bool)
ssl = attr.ib() # SSLContext or None
proxy = attr.ib() # URL or None
proxy_auth = attr.ib() # BasicAuth
proxy_headers_hash = attr.ib(type=int) # hash(CIMultiDict)


def _is_expected_content_type(response_content_type, expected_content_type):
Expand Down Expand Up @@ -237,7 +246,14 @@ def ssl(self):

@property
def connection_key(self):
return ConnectionKey(self.host, self.port, self.is_ssl())
proxy_headers = self.proxy_headers
if proxy_headers:
h = hash(tuple((k, v) for k, v in proxy_headers.items()))
else:
h = None
return ConnectionKey(self.host, self.port, self.is_ssl(),
self.ssl,
self.proxy, self.proxy_auth, h)

@property
def host(self):
Expand Down
14 changes: 10 additions & 4 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from itertools import cycle, islice
from time import monotonic

import attr

from . import hdrs, helpers
from .client_exceptions import (ClientConnectionError,
ClientConnectorCertificateError,
Expand Down Expand Up @@ -272,7 +274,8 @@ def _cleanup(self):
if proto.is_connected():
if use_time - deadline < 0:
transport = proto.close()
if key[-1] and not self._cleanup_closed_disabled:
if (key.is_ssl and
not self._cleanup_closed_disabled):
self._cleanup_closed_transports.append(
transport)
else:
Expand Down Expand Up @@ -482,7 +485,7 @@ def _get(self, key):
if t1 - t0 > self._keepalive_timeout:
transport = proto.close()
# only for SSL transports
if key[-1] and not self._cleanup_closed_disabled:
if key.is_ssl and not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transport)
else:
if not conns:
Expand Down Expand Up @@ -546,7 +549,7 @@ def _release(self, key, protocol, *, should_close=False):
if should_close or protocol.should_close:
transport = protocol.close()

if key[-1] and not self._cleanup_closed_disabled:
if key.is_ssl and not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transport)
else:
conns = self._conns.get(key)
Expand Down Expand Up @@ -918,7 +921,10 @@ async def _create_proxy_connection(self, req, traces=None):
# asyncio handles this perfectly
proxy_req.method = hdrs.METH_CONNECT
proxy_req.url = req.url
key = (req.host, req.port, req.ssl)
key = attr.evolve(req.connection_key,
proxy=None,
proxy_auth=None,
proxy_headers_hash=None)
conn = Connection(self, key, proto, self._loop)
proxy_resp = await proxy_req.send(conn)
try:
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ syscalls
Systemd
tarball
TCP
TLS
teardown
Teardown
TestClient
Expand Down
77 changes: 36 additions & 41 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import aiohttp
from aiohttp import client, web
from aiohttp.client import ClientRequest
from aiohttp.client_reqrep import ConnectionKey
from aiohttp.connector import Connection, _DNSCacheTable
from aiohttp.test_utils import make_mocked_coro, unused_port
from aiohttp.tracing import Trace
Expand All @@ -25,19 +26,19 @@
@pytest.fixture()
def key():
"""Connection key"""
return ('localhost1', 80, False)
return ConnectionKey('localhost', 80, False, None, None, None, None)


@pytest.fixture
def key2():
"""Connection key"""
return ('localhost2', 80, False)
return ConnectionKey('localhost', 80, False, None, None, None, None)


@pytest.fixture
def ssl_key():
"""Connection key"""
return ('localhost', 80, True)
return ConnectionKey('localhost', 80, True, None, None, None, None)


@pytest.fixture
Expand Down Expand Up @@ -266,22 +267,24 @@ def test_get(loop):

def test_get_expired(loop):
conn = aiohttp.BaseConnector(loop=loop)
assert conn._get(('localhost', 80, False)) is None
key = ConnectionKey('localhost', 80, False, None, None, None, None)
assert conn._get(key) is None

proto = mock.Mock()
conn._conns[('localhost', 80, False)] = [(proto, loop.time() - 1000)]
assert conn._get(('localhost', 80, False)) is None
conn._conns[key] = [(proto, loop.time() - 1000)]
assert conn._get(key) is None
assert not conn._conns
conn.close()


def test_get_expired_ssl(loop):
conn = aiohttp.BaseConnector(loop=loop, enable_cleanup_closed=True)
assert conn._get(('localhost', 80, True)) is None
key = ConnectionKey('localhost', 80, True, None, None, None, None)
assert conn._get(key) is None

proto = mock.Mock()
conn._conns[('localhost', 80, True)] = [(proto, loop.time() - 1000)]
assert conn._get(('localhost', 80, True)) is None
conn._conns[key] = [(proto, loop.time() - 1000)]
assert conn._get(key) is None
assert not conn._conns
assert conn._cleanup_closed_transports == [proto.close.return_value]
conn.close()
Expand Down Expand Up @@ -443,11 +446,10 @@ def test_release_waiter_no_available(loop, key, key2):
conn.close()


def test_release_close(loop):
def test_release_close(loop, key):
conn = aiohttp.BaseConnector(loop=loop)
proto = mock.Mock(should_close=True)

key = ('localhost', 80, False)
conn._acquired.add(proto)
conn._release(key, proto)
assert not conn._conns
Expand Down Expand Up @@ -946,20 +948,18 @@ def test_get_pop_empty_conns(loop):
assert not conn._conns


def test_release_close_do_not_add_to_pool(loop):
def test_release_close_do_not_add_to_pool(loop, key):
# see issue #473
conn = aiohttp.BaseConnector(loop=loop)

key = ('127.0.0.1', 80, False)
proto = mock.Mock(should_close=True)

conn._acquired.add(proto)
conn._release(key, proto)
assert not conn._conns


def test_release_close_do_not_delete_existing_connections(loop):
key = ('127.0.0.1', 80, False)
def test_release_close_do_not_delete_existing_connections(loop, key):
proto1 = mock.Mock()

conn = aiohttp.BaseConnector(loop=loop)
Expand Down Expand Up @@ -987,24 +987,22 @@ def test_release_not_started(loop):
conn.close()


def test_release_not_opened(loop):
def test_release_not_opened(loop, key):
conn = aiohttp.BaseConnector(loop=loop)

proto = mock.Mock()
key = ('localhost', 80, False)
conn._acquired.add(proto)
conn._release(key, proto)
assert proto.close.called


async def test_connect(loop):
async def test_connect(loop, key):
proto = mock.Mock()
proto.is_connected.return_value = True

req = ClientRequest('GET', URL('http://host:80'), loop=loop)
req = ClientRequest('GET', URL('http://localhost:80'), loop=loop)

conn = aiohttp.BaseConnector(loop=loop)
key = ('host', 80, False)
conn._conns[key] = [(proto, loop.time())]
conn._create_connection = mock.Mock()
conn._create_connection.return_value = loop.create_future()
Expand Down Expand Up @@ -1098,8 +1096,7 @@ def test_ctor_cleanup():
assert conn._cleanup_closed_handle is not None


def test_cleanup():
key = ('localhost', 80, False)
def test_cleanup(key):
testset = {
key: [(mock.Mock(), 10),
(mock.Mock(), 300)],
Expand All @@ -1119,10 +1116,9 @@ def test_cleanup():
assert conn._cleanup_handle is not None


def test_cleanup_close_ssl_transport():
def test_cleanup_close_ssl_transport(ssl_key):
proto = mock.Mock()
key = ('localhost', 80, True)
testset = {key: [(proto, 10)]}
testset = {ssl_key: [(proto, 10)]}

loop = mock.Mock()
loop.time.return_value = 300
Expand Down Expand Up @@ -1153,8 +1149,7 @@ def test_cleanup2():
conn.close()


def test_cleanup3():
key = ('localhost', 80, False)
def test_cleanup3(key):
testset = {key: [(mock.Mock(), 290.1),
(mock.Mock(), 305.1)]}
testset[key][0][0].is_connected.return_value = True
Expand Down Expand Up @@ -1374,7 +1369,7 @@ async def test_connect_with_limit(loop, key):
proto = mock.Mock()
proto.is_connected.return_value = True

req = ClientRequest('GET', URL('http://localhost1:80'),
req = ClientRequest('GET', URL('http://localhost:80'),
loop=loop,
response_class=mock.Mock())

Expand Down Expand Up @@ -1499,7 +1494,7 @@ async def test_connect_reuseconn_tracing(loop, key):
proto = mock.Mock()
proto.is_connected.return_value = True

req = ClientRequest('GET', URL('http://localhost1:80'),
req = ClientRequest('GET', URL('http://localhost:80'),
loop=loop,
response_class=mock.Mock())

Expand All @@ -1520,7 +1515,7 @@ async def test_connect_with_limit_and_limit_per_host(loop, key):
proto = mock.Mock()
proto.is_connected.return_value = True

req = ClientRequest('GET', URL('http://localhost1:80'), loop=loop)
req = ClientRequest('GET', URL('http://localhost:80'), loop=loop)

conn = aiohttp.BaseConnector(loop=loop, limit=1000, limit_per_host=1)
conn._conns[key] = [(proto, loop.time())]
Expand Down Expand Up @@ -1586,7 +1581,7 @@ async def test_connect_with_no_limits(loop, key):
proto = mock.Mock()
proto.is_connected.return_value = True

req = ClientRequest('GET', URL('http://localhost1:80'), loop=loop)
req = ClientRequest('GET', URL('http://localhost:80'), loop=loop)

conn = aiohttp.BaseConnector(loop=loop, limit=0, limit_per_host=0)
conn._conns[key] = [(proto, loop.time())]
Expand Down Expand Up @@ -1834,11 +1829,11 @@ def test_force_close_and_explicit_keep_alive(loop):
assert conn


async def test_error_on_connection(loop):
async def test_error_on_connection(loop, key):
conn = aiohttp.BaseConnector(limit=1, loop=loop)

req = mock.Mock()
req.connection_key = 'key'
req.connection_key = key
proto = mock.Mock()
i = 0

Expand All @@ -1861,16 +1856,16 @@ async def create_connection(req, traces=None):
await asyncio.sleep(0, loop=loop)
assert not t1.done()
assert not t2.done()
assert len(conn._acquired_per_host['key']) == 1
assert len(conn._acquired_per_host[key]) == 1

fut.set_result(None)
with pytest.raises(OSError):
await t1

ret = await t2
assert len(conn._acquired_per_host['key']) == 1
assert len(conn._acquired_per_host[key]) == 1

assert ret._key == 'key'
assert ret._key == key
assert ret.protocol == proto
assert proto in conn._acquired
ret.release()
Expand Down Expand Up @@ -1898,11 +1893,11 @@ async def create_connection(req, traces=None):
await conn2


async def test_error_on_connection_with_cancelled_waiter(loop):
async def test_error_on_connection_with_cancelled_waiter(loop, key):
conn = aiohttp.BaseConnector(limit=1, loop=loop)

req = mock.Mock()
req.connection_key = 'key'
req.connection_key = key
proto = mock.Mock()
i = 0

Expand All @@ -1929,7 +1924,7 @@ async def create_connection(req, traces=None):
await asyncio.sleep(0, loop=loop)
assert not t1.done()
assert not t2.done()
assert len(conn._acquired_per_host['key']) == 1
assert len(conn._acquired_per_host[key]) == 1

fut1.set_result(None)
fut2.cancel()
Expand All @@ -1940,9 +1935,9 @@ async def create_connection(req, traces=None):
await t2

ret = await t3
assert len(conn._acquired_per_host['key']) == 1
assert len(conn._acquired_per_host[key]) == 1

assert ret._key == 'key'
assert ret._key == key
assert ret.protocol == proto
assert proto in conn._acquired
ret.release()
Expand Down

0 comments on commit 920b058

Please sign in to comment.