From f52486b3f3fa57b2c5f3de82c4acb2e497f1b0a0 Mon Sep 17 00:00:00 2001 From: davfsa Date: Sun, 25 Jun 2023 22:06:11 +0200 Subject: [PATCH 1/2] Retry REST requests on connection errors Co-authored-by: Leo Developer --- hikari/impl/rest.py | 31 ++++++++++++++++++++++++++++--- tests/hikari/impl/test_rest.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 77030bdc04..024704ff65 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -736,8 +736,10 @@ async def _request( raise errors.ComponentStateConflictError("The REST client was closed mid-request") + # Ignore too long and too complex, respectively + # We rather keep everything we can here inline. @typing.final - async def _perform_request( + async def _perform_request( # noqa: CFQ001, C901 self, compiled_route: routes.CompiledRoute, *, @@ -779,13 +781,14 @@ async def _perform_request( url = compiled_route.create_url(self._rest_url) stack = contextlib.AsyncExitStack() - # This is initiated the first time we hit a 5xx error to save a little memory when nothing goes wrong + # This is initiated the first time we time out or hit a 5xx error to + # save a little memory when nothing goes wrong backoff: typing.Optional[rate_limits.ExponentialBackOff] = None retry_count = 0 trace_logging_enabled = _LOGGER.isEnabledFor(ux.TRACE) while True: - async with stack: + try: if form_builder: data = await form_builder.build(stack, executor=self._executor) @@ -832,6 +835,28 @@ async def _perform_request( # Ensure we are not rate limited, and update rate limiting headers where appropriate. time_before_retry = await self._parse_ratelimits(compiled_route, auth, response) + except (asyncio.TimeoutError, aiohttp.ClientConnectionError) as ex: + if retry_count >= self._max_retries: + raise errors.HTTPError(message=str(ex)) from ex + + if backoff is None: + backoff = rate_limits.ExponentialBackOff(maximum=_MAX_BACKOFF_DURATION) + + sleep_time = next(backoff) + _LOGGER.warning( + "Connection error (%s), backing off for %.2fs and retrying. Retries remaining: %s", + type(ex).__name__, + sleep_time, + self._max_retries - retry_count, + ) + retry_count += 1 + + await asyncio.sleep(sleep_time) + continue + + finally: + await stack.aclose() + if time_before_retry is not None: await asyncio.sleep(time_before_retry) continue diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 972251b607..15132aa748 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -26,6 +26,7 @@ import re import typing +import aiohttp import mock import pytest @@ -1913,6 +1914,33 @@ class StubResponse: asyncio_sleep.assert_has_awaits([mock.call(1), mock.call(2), mock.call(3)]) generate_error_response.assert_called_once_with(rest_client._client_session.request.return_value) + @hikari_test_helpers.timeout() + @pytest.mark.parametrize("exception", [asyncio.TimeoutError, aiohttp.ClientConnectionError]) + async def test_perform_request_when_connection_error_will_retry_until_exhausted(self, rest_client, exception): + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exception)) + rest_client._max_retries = 3 + rest_client._parse_ratelimits = mock.AsyncMock() + rest_client._client_session = mock_session + + stack = contextlib.ExitStack() + stack.enter_context(pytest.raises(errors.HTTPError)) + exponential_backoff = stack.enter_context( + mock.patch.object( + rate_limits, + "ExponentialBackOff", + return_value=mock.Mock(__next__=mock.Mock(side_effect=[1, 2, 3, 4, 5])), + ) + ) + asyncio_sleep = stack.enter_context(mock.patch.object(asyncio, "sleep")) + + with stack: + await rest_client._perform_request(route) + + assert exponential_backoff.return_value.__next__.call_count == 3 + exponential_backoff.assert_called_once_with(maximum=16) + asyncio_sleep.assert_has_awaits([mock.call(1), mock.call(2), mock.call(3)]) + @pytest.mark.parametrize("enabled", [True, False]) @hikari_test_helpers.timeout() async def test_perform_request_logger(self, rest_client, enabled): From cd43ba669b69d872ebf0745d2d1a2c50b2823a92 Mon Sep 17 00:00:00 2001 From: davfsa Date: Sun, 25 Jun 2023 22:18:00 +0200 Subject: [PATCH 2/2] Add changelog file --- changes/1648.bugfix.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/1648.bugfix.md diff --git a/changes/1648.bugfix.md b/changes/1648.bugfix.md new file mode 100644 index 0000000000..64c2b65966 --- /dev/null +++ b/changes/1648.bugfix.md @@ -0,0 +1 @@ +Retry REST requests on connection errors