Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry REST requests on connection errors #1648

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/1648.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Retry REST requests on connection errors
31 changes: 28 additions & 3 deletions hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,8 +736,10 @@ async def _request(

raise errors.ComponentStateConflictError("The REST client was closed mid-request")

# Ignore too long and too complex, respectively
# We rather keep everything we can here inline.
@typing.final
async def _perform_request(
async def _perform_request( # noqa: CFQ001, C901
self,
compiled_route: routes.CompiledRoute,
*,
Expand Down Expand Up @@ -779,13 +781,14 @@ async def _perform_request(
url = compiled_route.create_url(self._rest_url)

stack = contextlib.AsyncExitStack()
# This is initiated the first time we hit a 5xx error to save a little memory when nothing goes wrong
# This is initiated the first time we time out or hit a 5xx error to
# save a little memory when nothing goes wrong
backoff: typing.Optional[rate_limits.ExponentialBackOff] = None
retry_count = 0
trace_logging_enabled = _LOGGER.isEnabledFor(ux.TRACE)

while True:
async with stack:
try:
if form_builder:
data = await form_builder.build(stack, executor=self._executor)

Expand Down Expand Up @@ -832,6 +835,28 @@ async def _perform_request(
# Ensure we are not rate limited, and update rate limiting headers where appropriate.
time_before_retry = await self._parse_ratelimits(compiled_route, auth, response)

except (asyncio.TimeoutError, aiohttp.ClientConnectionError) as ex:
if retry_count >= self._max_retries:
raise errors.HTTPError(message=str(ex)) from ex

if backoff is None:
backoff = rate_limits.ExponentialBackOff(maximum=_MAX_BACKOFF_DURATION)

sleep_time = next(backoff)
_LOGGER.warning(
"Connection error (%s), backing off for %.2fs and retrying. Retries remaining: %s",
type(ex).__name__,
sleep_time,
self._max_retries - retry_count,
)
retry_count += 1

await asyncio.sleep(sleep_time)
continue

finally:
await stack.aclose()

if time_before_retry is not None:
await asyncio.sleep(time_before_retry)
continue
Expand Down
28 changes: 28 additions & 0 deletions tests/hikari/impl/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import re
import typing

import aiohttp
import mock
import pytest

Expand Down Expand Up @@ -1913,6 +1914,33 @@ class StubResponse:
asyncio_sleep.assert_has_awaits([mock.call(1), mock.call(2), mock.call(3)])
generate_error_response.assert_called_once_with(rest_client._client_session.request.return_value)

@hikari_test_helpers.timeout()
@pytest.mark.parametrize("exception", [asyncio.TimeoutError, aiohttp.ClientConnectionError])
async def test_perform_request_when_connection_error_will_retry_until_exhausted(self, rest_client, exception):
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exception))
rest_client._max_retries = 3
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._client_session = mock_session

stack = contextlib.ExitStack()
stack.enter_context(pytest.raises(errors.HTTPError))
exponential_backoff = stack.enter_context(
mock.patch.object(
rate_limits,
"ExponentialBackOff",
return_value=mock.Mock(__next__=mock.Mock(side_effect=[1, 2, 3, 4, 5])),
)
)
asyncio_sleep = stack.enter_context(mock.patch.object(asyncio, "sleep"))

with stack:
await rest_client._perform_request(route)

assert exponential_backoff.return_value.__next__.call_count == 3
exponential_backoff.assert_called_once_with(maximum=16)
asyncio_sleep.assert_has_awaits([mock.call(1), mock.call(2), mock.call(3)])

@pytest.mark.parametrize("enabled", [True, False])
@hikari_test_helpers.timeout()
async def test_perform_request_logger(self, rest_client, enabled):
Expand Down