Skip to content

Commit

Permalink
Choose cached addrs using a round robin strategy
Browse files Browse the repository at this point in the history
At each attempt to open a new connection, the addrs cached for that
specific host are retrieved using a round robin strategy. As a
result, all addresses resolved by the DNS query will be used.
  • Loading branch information
pfreixes committed Apr 23, 2017
1 parent 4bd7040 commit 3da3a01
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 53 deletions.
77 changes: 59 additions & 18 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import warnings
from collections import defaultdict
from hashlib import md5, sha1, sha256
from itertools import cycle, islice
from time import monotonic
from types import MappingProxyType

from . import hdrs, helpers
Expand Down Expand Up @@ -483,6 +485,55 @@ def _create_connection(self, req):
_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)


class _DNSCacheTable:

def __init__(self, ttl=None):
self._addrs = {}
self._addrs_rr = {}
self._timestamps = {}
self._ttl = ttl

def __contains__(self, host):
return host in self._addrs

@property
def addrs(self):
return self._addrs

def add(self, host, addrs):
self._addrs[host] = addrs
self._addrs_rr[host] = cycle(addrs)

if self._ttl:
self._timestamps[host] = monotonic()

def remove(self, host):
self._addrs.pop(host, None)
self._addrs_rr.pop(host, None)

if self._ttl:
self._timestamps.pop(host, None)

def clear(self):
self._addrs.clear()
self._addrs_rr.clear()
self._timestamps.clear()

def next_addrs(self, host):
# Return an iterator that will get at maximum as many addrs
# there are for the specific host starting from the last
# not itereated addr.
return islice(self._addrs_rr[host], len(self._addrs[host]))

def expired(self, host):
if self._ttl is None:
return False

return (
self._timestamps[host] + self._ttl
) < monotonic()


class TCPConnector(BaseConnector):
"""TCP connector.
Expand Down Expand Up @@ -545,9 +596,7 @@ def __init__(self, *, verify_ssl=True, fingerprint=None,
self._resolver = resolver

self._use_dns_cache = use_dns_cache
self._ttl_dns_cache = ttl_dns_cache
self._cached_hosts = {}
self._cached_hosts_timestamp = {}
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
self._ssl_context = ssl_context
self._family = family
self._local_addr = local_addr
Expand Down Expand Up @@ -593,26 +642,17 @@ def use_dns_cache(self):
@property
def cached_hosts(self):
"""Read-only dict of cached DNS record."""
return MappingProxyType(self._cached_hosts)
return MappingProxyType(self._cached_hosts.addrs)

def clear_dns_cache(self, host=None, port=None):
"""Remove specified host/port or clear all dns local cache."""
if host is not None and port is not None:
self._cached_hosts.pop((host, port), None)
self._cached_hosts_timestamp.pop((host, port), None)
self._cached_hosts.remove((host, port))
elif host is not None or port is not None:
raise ValueError("either both host and port "
"or none of them are allowed")
else:
self._cached_hosts.clear()
self._cached_hosts_timestamp.clear()

def _dns_entry_expired(self, key):
if self._ttl_dns_cache is None:
return False
return (
self._cached_hosts_timestamp[key] + self._ttl_dns_cache
) < self._loop.time()

@asyncio.coroutine
def _resolve_host(self, host, port):
Expand All @@ -623,12 +663,13 @@ def _resolve_host(self, host, port):
if self._use_dns_cache:
key = (host, port)

if key not in self._cached_hosts or (self._dns_entry_expired(key)):
self._cached_hosts[key] = yield from \
if key not in self._cached_hosts or\
self._cached_hosts.expired(key):
addrs = yield from \
self._resolver.resolve(host, port, family=self._family)
self._cached_hosts_timestamp[key] = self._loop.time()
self._cached_hosts.add(key, addrs)

return self._cached_hosts[key]
return self._cached_hosts.next_addrs(key)
else:
res = yield from self._resolver.resolve(
host, port, family=self._family)
Expand Down
137 changes: 102 additions & 35 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ssl
import tempfile
import unittest
from time import sleep
from unittest import mock

import pytest
Expand All @@ -17,7 +18,7 @@
import aiohttp
from aiohttp import client, helpers, web
from aiohttp.client import ClientRequest
from aiohttp.connector import Connection
from aiohttp.connector import Connection, _DNSCacheTable
from aiohttp.test_utils import unused_port


Expand Down Expand Up @@ -365,40 +366,57 @@ def test_tcp_connector_resolve_host(loop):


@asyncio.coroutine
def test_tcp_connector_dns_cache_not_expired(loop):
conn = aiohttp.TCPConnector(
loop=loop,
use_dns_cache=True,
ttl_dns_cache=10
)
def dns_response():
return ["127.0.0.1"]

res = yield from conn._resolve_host('localhost', 8080)
res2 = yield from conn._resolve_host('localhost', 8080)

assert res is res2
@asyncio.coroutine
def test_tcp_connector_dns_cache_not_expired(loop):
with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver:
conn = aiohttp.TCPConnector(
loop=loop,
use_dns_cache=True,
ttl_dns_cache=10
)
m_resolver().resolve.return_value = dns_response()
yield from conn._resolve_host('localhost', 8080)
yield from conn._resolve_host('localhost', 8080)
m_resolver().resolve.assert_called_once_with(
'localhost',
8080,
family=0
)


@asyncio.coroutine
def test_tcp_connector_dns_cache_forever(loop):
conn = aiohttp.TCPConnector(
loop=loop,
use_dns_cache=True,
ttl_dns_cache=None
)

res = yield from conn._resolve_host('localhost', 8080)
res2 = yield from conn._resolve_host('localhost', 8080)
assert res is res2
with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver:
conn = aiohttp.TCPConnector(
loop=loop,
use_dns_cache=True,
ttl_dns_cache=10
)
m_resolver().resolve.return_value = dns_response()
yield from conn._resolve_host('localhost', 8080)
yield from conn._resolve_host('localhost', 8080)
m_resolver().resolve.assert_called_once_with(
'localhost',
8080,
family=0
)


@asyncio.coroutine
def test_tcp_connector_use_dns_cache_disabled(loop):
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)

res = yield from conn._resolve_host('localhost', 8080)
res2 = yield from conn._resolve_host('localhost', 8080)

assert res is not res2
with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver:
conn = aiohttp.TCPConnector(loop=loop, use_dns_cache=False)
m_resolver().resolve.return_value = dns_response()
yield from conn._resolve_host('localhost', 8080)
yield from conn._resolve_host('localhost', 8080)
m_resolver().resolve.assert_has_calls([
mock.call('localhost', 8080, family=0),
mock.call('localhost', 8080, family=0)
])


def test_get_pop_empty_conns(loop):
Expand Down Expand Up @@ -631,20 +649,15 @@ def test_tcp_connector_fingerprint_invalid(loop):

def test_tcp_connector_clear_dns_cache(loop):
conn = aiohttp.TCPConnector(loop=loop)
info = object()
conn._cached_hosts[('localhost', 123)] = info
conn._cached_hosts_timestamp[('localhost', 123)] = 100
conn._cached_hosts[('localhost', 124)] = info
conn._cached_hosts_timestamp[('localhost', 124)] = 101
hosts = ['a', 'b']
conn._cached_hosts.add(('localhost', 123), hosts)
conn._cached_hosts.add(('localhost', 124), hosts)
conn.clear_dns_cache('localhost', 123)
assert conn.cached_hosts == {('localhost', 124): info}
assert conn._cached_hosts_timestamp == {('localhost', 124): 101}
assert ('localhost', 123) not in conn.cached_hosts
conn.clear_dns_cache('localhost', 123)
assert conn.cached_hosts == {('localhost', 124): info}
assert conn._cached_hosts_timestamp == {('localhost', 124): 101}
assert ('localhost', 123) not in conn.cached_hosts
conn.clear_dns_cache()
assert conn.cached_hosts == {}
assert conn._cached_hosts_timestamp == {}


def test_tcp_connector_clear_dns_cache_bad_args(loop):
Expand Down Expand Up @@ -1181,3 +1194,57 @@ def test_resolver_not_called_with_address_is_ip(self):
self.loop.run_until_complete(connector.connect(req))

resolver.resolve.assert_not_called()


class TestDNSCacheTable:
    """Unit tests for the ``_DNSCacheTable`` helper used by TCPConnector."""

    @pytest.fixture
    def dns_cache_table(self):
        """Provide an empty table created without a TTL."""
        return _DNSCacheTable()

    def test_addrs(self, dns_cache_table):
        dns_cache_table.add('localhost', ['127.0.0.1'])
        dns_cache_table.add('foo', ['127.0.0.2'])
        assert dns_cache_table.addrs == {
            'localhost': ['127.0.0.1'],
            'foo': ['127.0.0.2']
        }

    def test_remove(self, dns_cache_table):
        dns_cache_table.add('localhost', ['127.0.0.1'])
        dns_cache_table.remove('localhost')
        assert dns_cache_table.addrs == {}

    def test_clear(self, dns_cache_table):
        dns_cache_table.add('localhost', ['127.0.0.1'])
        dns_cache_table.clear()
        assert dns_cache_table.addrs == {}

    def test_not_expired_ttl_None(self, dns_cache_table):
        dns_cache_table.add('localhost', ['127.0.0.1'])
        assert not dns_cache_table.expired('localhost')

    def test_not_expired_ttl(self):
        # Drive the clock explicitly instead of relying on real elapsed
        # time, so the outcome does not depend on machine speed.
        with mock.patch('aiohttp.connector.monotonic') as m_monotonic:
            m_monotonic.return_value = 100.0
            dns_cache_table = _DNSCacheTable(ttl=0.1)
            dns_cache_table.add('localhost', ['127.0.0.1'])
            m_monotonic.return_value = 100.05
            assert not dns_cache_table.expired('localhost')

    def test_expired_ttl(self):
        # Sleeping exactly the TTL is fragile because expiry uses a
        # strict "<" comparison; move the patched clock clearly past
        # the TTL instead.
        with mock.patch('aiohttp.connector.monotonic') as m_monotonic:
            m_monotonic.return_value = 100.0
            dns_cache_table = _DNSCacheTable(ttl=0.1)
            dns_cache_table.add('localhost', ['127.0.0.1'])
            m_monotonic.return_value = 100.2
            assert dns_cache_table.expired('localhost')

    def test_next_addrs(self, dns_cache_table):
        dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2'])

        # max elements returned are the full list of addrs
        addrs = list(dns_cache_table.next_addrs('foo'))
        assert addrs == ['127.0.0.1', '127.0.0.2']

        # different calls to next_addrs return the hosts using
        # a round robin strategy.
        addrs = dns_cache_table.next_addrs('foo')
        assert next(addrs) == '127.0.0.1'
        addrs = dns_cache_table.next_addrs('foo')
        assert next(addrs) == '127.0.0.2'

0 comments on commit 3da3a01

Please sign in to comment.