Skip to content

Commit

Permalink
Fix DNSCache race-condition (#2620)
Browse files Browse the repository at this point in the history
  • Loading branch information
socketpair committed Dec 26, 2017
1 parent 2c169cb commit 668fc3c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGES/2620.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed race-condition for iterating addresses from the DNSCache.
8 changes: 4 additions & 4 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,10 +578,10 @@ def clear(self):
self._timestamps.clear()

def next_addrs(self, host):
# Return an iterator that will get at maximum as many addrs
# there are for the specific host starting from the last
# not itereated addr.
return islice(self._addrs_rr[host], len(self._addrs[host]))
loop = self._addrs_rr[host]
addrs = list(islice(loop, len(self._addrs[host])))
next(loop) # Consume one more element to shift internal state of `cycle`
return addrs

def expired(self, host):
if self._ttl is None:
Expand Down
28 changes: 20 additions & 8 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1959,15 +1959,27 @@ async def test_expired_ttl(self, loop):
assert dns_cache_table.expired('localhost')

def test_next_addrs(self, dns_cache_table):
dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2'])
dns_cache_table.add('foo', ['127.0.0.1', '127.0.0.2', '127.0.0.3'])

# max elements returned are the full list of addrs
addrs = list(dns_cache_table.next_addrs('foo'))
assert addrs == ['127.0.0.1', '127.0.0.2']

# different calls to next_addrs return the hosts using
# Each calls to next_addrs return the hosts using
# a round robin strategy.
addrs = dns_cache_table.next_addrs('foo')
assert next(addrs) == '127.0.0.1'
assert addrs == ['127.0.0.1', '127.0.0.2', '127.0.0.3']

addrs = dns_cache_table.next_addrs('foo')
assert addrs == ['127.0.0.2', '127.0.0.3', '127.0.0.1']

addrs = dns_cache_table.next_addrs('foo')
assert addrs == ['127.0.0.3', '127.0.0.1', '127.0.0.2']

addrs = dns_cache_table.next_addrs('foo')
assert addrs == ['127.0.0.1', '127.0.0.2', '127.0.0.3']

def test_next_addrs_single(self, dns_cache_table):
dns_cache_table.add('foo', ['127.0.0.1'])

addrs = dns_cache_table.next_addrs('foo')
assert addrs == ['127.0.0.1']

addrs = dns_cache_table.next_addrs('foo')
assert next(addrs) == '127.0.0.2'
assert addrs == ['127.0.0.1']

0 comments on commit 668fc3c

Please sign in to comment.