diff --git a/aiohttp/connector.py b/aiohttp/connector.py index f6a59437925..2cb2e9d2485 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -6,6 +6,8 @@ import warnings from collections import defaultdict from hashlib import md5, sha1, sha256 +from itertools import cycle, islice +from time import monotonic from types import MappingProxyType from . import hdrs, helpers @@ -483,6 +485,55 @@ def _create_connection(self, req): _SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0) +class _DNSCacheTable: + + def __init__(self, ttl=None): + self._addrs = {} + self._addrs_rr = {} + self._timestamps = {} + self._ttl = ttl + + def __contains__(self, host): + return host in self._addrs + + @property + def addrs(self): + return self._addrs + + def add(self, host, addrs): + self._addrs[host] = addrs + self._addrs_rr[host] = cycle(addrs) + + if self._ttl: + self._timestamps[host] = monotonic() + + def remove(self, host): + self._addrs.pop(host, None) + self._addrs_rr.pop(host, None) + + if self._ttl: + self._timestamps.pop(host, None) + + def clear(self): + self._addrs.clear() + self._addrs_rr.clear() + self._timestamps.clear() + + def next_addrs(self, host): + # Return an iterator that will get at maximum as many addrs + # there are for the specific host starting from the last + # not itereated addr. + return islice(self._addrs_rr[host], len(self._addrs[host])) + + def expired(self, host): + if self._ttl is None: + return False + + return ( + self._timestamps[host] + self._ttl + ) < monotonic() + + class TCPConnector(BaseConnector): """TCP connector. @@ -545,9 +596,7 @@ def __init__(self, *, verify_ssl=True, fingerprint=None, self._resolver = resolver self._use_dns_cache = use_dns_cache - self._ttl_dns_cache = ttl_dns_cache - self._cached_hosts = {} - self._cached_hosts_timestamp = {} + self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) self._ssl_context = ssl_context self._family = family self._local_addr = local_addr @@ -593,26 +642,17 @@ def use_dns_cache(self): @property def cached_hosts(self): """Read-only dict of cached DNS record.""" - return MappingProxyType(self._cached_hosts) + return MappingProxyType(self._cached_hosts.addrs) def clear_dns_cache(self, host=None, port=None): """Remove specified host/port or clear all dns local cache.""" if host is not None and port is not None: - self._cached_hosts.pop((host, port), None) - self._cached_hosts_timestamp.pop((host, port), None) + self._cached_hosts.remove((host, port)) elif host is not None or port is not None: raise ValueError("either both host and port " "or none of them are allowed") else: self._cached_hosts.clear() - self._cached_hosts_timestamp.clear() - - def _dns_entry_expired(self, key): - if self._ttl_dns_cache is None: - return False - return ( - self._cached_hosts_timestamp[key] + self._ttl_dns_cache - ) < self._loop.time() @asyncio.coroutine def _resolve_host(self, host, port): @@ -623,12 +663,13 @@ def _resolve_host(self, host, port): if self._use_dns_cache: key = (host, port) - if key not in self._cached_hosts or (self._dns_entry_expired(key)): - self._cached_hosts[key] = yield from \ + if key not in self._cached_hosts or\ + self._cached_hosts.expired(key): + addrs = yield from \ self._resolver.resolve(host, port, family=self._family) - self._cached_hosts_timestamp[key] = self._loop.time() + self._cached_hosts.add(key, addrs) - return self._cached_hosts[key] + return self._cached_hosts.next_addrs(key) else: res = yield from self._resolver.resolve( host, port, family=self._family) diff --git a/tests/test_connector.py b/tests/test_connector.py index d6e088823c0..41ff713074a 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -9,6 +9,7 @@ import ssl import tempfile import unittest +from time import sleep from unittest import mock import pytest @@ -17,7 +18,7 @@ import aiohttp from aiohttp import client, helpers, web from aiohttp.client import ClientRequest -from aiohttp.connector import Connection +from aiohttp.connector import Connection, _DNSCacheTable from aiohttp.test_utils import unused_port @@ -365,40 +366,57 @@ def test_tcp_connector_resolve_host(loop): @asyncio.coroutine -def test_tcp_connector_dns_cache_not_expired(loop): - conn = aiohttp.TCPConnector( - loop=loop, - use_dns_cache=True, - ttl_dns_cache=10 - ) +def dns_response(): + return ["127.0.0.1"] - res = yield from conn._resolve_host('localhost', 8080) - res2 = yield from conn._resolve_host('localhost', 8080) - assert res is res2 +@asyncio.coroutine +def test_tcp_connector_dns_cache_not_expired(loop): + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: + conn = aiohttp.TCPConnector( + loop=loop, + use_dns_cache=True, + ttl_dns_cache=10 + ) + m_resolver().resolve.return_value = dns_response() + yield from conn._resolve_host('localhost', 8080) + yield from conn._resolve_host('localhost', 8080) + m_resolver().resolve.assert_called_once_with( + 'localhost', + 8080, + family=0 + ) @asyncio.coroutine def test_tcp_connector_dns_cache_forever(loop): - conn = aiohttp.TCPConnector( - loop=loop, - use_dns_cache=True, - ttl_dns_cache=None - ) - - res = yield from conn._resolve_host('localhost', 8080) - res2 = yield from conn._resolve_host('localhost', 8080) - assert res is res2 + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: + conn = aiohttp.TCPConnector( + loop=loop, + use_dns_cache=True, + ttl_dns_cache=10 + ) + m_resolver().resolve.return_value = dns_response() + yield from conn._resolve_host('localhost', 8080) + yield from conn._resolve_host('localhost', 8080) + m_resolver().resolve.assert_called_once_with( + 'localhost', + 8080, + family=0 + ) @asyncio.coroutine def test_tcp_connector_use_dns_cache_disabled(loop): - conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False) - - res = yield from conn._resolve_host('localhost', 8080) - res2 = yield from conn._resolve_host('localhost', 8080) - - assert res is not res2 + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: + conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False) + m_resolver().resolve.return_value = dns_response() + yield from conn._resolve_host('localhost', 8080) + yield from conn._resolve_host('localhost', 8080) + m_resolver().resolve.assert_has_calls([ + mock.call('localhost', 8080, family=0), + mock.call('localhost', 8080, family=0) + ]) def test_get_pop_empty_conns(loop): @@ -631,20 +649,15 @@ def test_tcp_connector_fingerprint_invalid(loop): def test_tcp_connector_clear_dns_cache(loop): conn = aiohttp.TCPConnector(loop=loop) - info = object() - conn._cached_hosts[('localhost', 123)] = info - conn._cached_hosts_timestamp[('localhost', 123)] = 100 - conn._cached_hosts[('localhost', 124)] = info - conn._cached_hosts_timestamp[('localhost', 124)] = 101 + hosts = ['a', 'b'] + conn._cached_hosts.add(('localhost', 123), hosts) + conn._cached_hosts.add(('localhost', 124), hosts) conn.clear_dns_cache('localhost', 123) - assert conn.cached_hosts == {('localhost', 124): info} - assert conn._cached_hosts_timestamp == {('localhost', 124): 101} + assert ('localhost', 123) not in conn.cached_hosts conn.clear_dns_cache('localhost', 123) - assert conn.cached_hosts == {('localhost', 124): info} - assert conn._cached_hosts_timestamp == {('localhost', 124): 101} + assert ('localhost', 123) not in conn.cached_hosts conn.clear_dns_cache() assert conn.cached_hosts == {} - assert conn._cached_hosts_timestamp == {} def test_tcp_connector_clear_dns_cache_bad_args(loop): @@ -1181,3 +1194,57 @@ def test_resolver_not_called_with_address_is_ip(self): self.loop.run_until_complete(connector.connect(req)) resolver.resolve.assert_not_called() + + +class TestDNSCacheTable: + + @pytest.fixture + def dns_cache_table(self): + return _DNSCacheTable() + + def test_addrs(self, dns_cache_table): + dns_cache_table.add('localhost', ['127.0.0.1']) + dns_cache_table.add('foo', ['127.0.0.2']) + assert dns_cache_table.addrs == { + 'localhost': ['127.0.0.1'], + 'foo': ['127.0.0.2'] + } + + def test_remove(self, dns_cache_table): + dns_cache_table.add('localhost', ['127.0.0.1']) + dns_cache_table.remove('localhost') + assert dns_cache_table.addrs == {} + + def test_clear(self, dns_cache_table): + dns_cache_table.add('localhost', ['127.0.0.1']) + dns_cache_table.clear() + assert dns_cache_table.addrs == {} + + def test_not_expired_ttl_None(self, dns_cache_table): + dns_cache_table.add('localhost', ['127.0.0.1']) + assert not dns_cache_table.expired('localhost') + + def test_not_expired_ttl(self): + dns_cache_table = _DNSCacheTable(ttl=0.1) + dns_cache_table.add('localhost', ['127.0.0.1']) + assert not dns_cache_table.expired('localhost') + + def test_expired_ttl(self): + dns_cache_table = _DNSCacheTable(ttl=0.1) + dns_cache_table.add('localhost', ['127.0.0.1']) + sleep(0.1) + assert dns_cache_table.expired('localhost') + + def test_next_addrs(self, dns_cache_table): + dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2']) + + # max elements returned are the full list of addrs + addrs = list(dns_cache_table.next_addrs('foo')) + assert addrs == ['127.0.0.1', '127.0.0.2'] + + # different calls to next_addrs return the hosts using + # a round robin strategy. + addrs = dns_cache_table.next_addrs('foo') + assert next(addrs) == '127.0.0.1' + addrs = dns_cache_table.next_addrs('foo') + assert next(addrs) == '127.0.0.2'