Skip to content

Commit

Permalink
Fix timeout
Browse files Browse the repository at this point in the history
Fixes: #64
  • Loading branch information
saghul committed Oct 16, 2023
1 parent 80d7ca2 commit 0a2db15
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
19 changes: 16 additions & 3 deletions aiodns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ def __init__(self, nameservers: Optional[Sequence[str]] = None,
raise RuntimeError(
'aiodns needs a SelectorEventLoop on Windows. See more: https://github.com/saghul/aiodns/issues/86')
kwargs.pop('sock_state_cb', None)
self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb, **kwargs)
timeout = kwargs.pop('timeout', None)
self._timeout = timeout
self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb,
timeout=timeout,
**kwargs)
if nameservers:
self.nameservers = nameservers
self._read_fds = set() # type: Set[int]
Expand Down Expand Up @@ -119,7 +123,7 @@ def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
self.loop.add_writer(fd, self._handle_event, fd, WRITE)
self._write_fds.add(fd)
if self._timer is None:
self._timer = self.loop.call_later(1.0, self._timer_cb)
self._start_timer()
else:
# socket is now closed
if fd in self._read_fds:
Expand All @@ -146,6 +150,15 @@ def _handle_event(self, fd: int, event: Any) -> None:
def _timer_cb(self) -> None:
if self._read_fds or self._write_fds:
self._channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD)
self._timer = self.loop.call_later(1.0, self._timer_cb)
self._start_timer()
else:
self._timer = None

def _start_timer(self):
timeout = self._timeout
if timeout is None or timeout < 0 or timeout > 1:
timeout = 1
elif timeout == 0:
timeout = 0.1

self._timer = self.loop.call_later(timeout, self._timer_cb)
2 changes: 1 addition & 1 deletion tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_query_bad_class(self):
self.assertRaises(ValueError, self.resolver.query, 'google.com', 'A', "INVALIDCLASS")

def test_query_timeout(self):
self.resolver = aiodns.DNSResolver(timeout=0.1, loop=self.loop)
self.resolver = aiodns.DNSResolver(timeout=0.1, tries=1, loop=self.loop)
self.resolver.nameservers = ['1.2.3.4']
f = self.resolver.query('google.com', 'A')
started = time.monotonic()
Expand Down

0 comments on commit 0a2db15

Please sign in to comment.