Skip to content

Commit

Permalink
Retry based on HTTP methods
Browse files Browse the repository at this point in the history
Co-authored-by: Pandede <[email protected]>
#96
  • Loading branch information
inyutin committed Oct 27, 2024
1 parent 1f050a5 commit 1bace91
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
3 changes: 3 additions & 0 deletions aiohttp_retry/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ async def _is_skip_retry(self, current_attempt: int, response: ClientResponse) -
if current_attempt == self._retry_options.attempts:
return True

if response.method not in self._retry_options.methods:
return True

if response.status >= _MIN_SERVER_ERROR_STATUS and self._retry_options.retry_all_server_errors:
return False

Expand Down
15 changes: 15 additions & 0 deletions aiohttp_retry/retry_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
attempts: int = 3, # How many times we should retry
statuses: Iterable[int] | None = None, # On which statuses we should retry
exceptions: Iterable[type[Exception]] | None = None, # On which exceptions we should retry, by default on all
methods: Iterable[str] | None = None, # On which HTTP methods we should retry
retry_all_server_errors: bool = True, # If should retry all 500 errors or not
# a callback that will run on response to decide if retry
evaluate_response_callback: EvaluateResponseCallbackType | None = None,
Expand All @@ -29,6 +30,10 @@ def __init__(
exceptions = set()
self.exceptions: Iterable[type[Exception]] = exceptions

if methods is None:
methods = {"HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE", "POST", "CONNECT", "PATCH"}
self.methods: Iterable[str] = methods

self.retry_all_server_errors = retry_all_server_errors
self.evaluate_response_callback = evaluate_response_callback

Expand All @@ -46,13 +51,15 @@ def __init__(
factor: float = 2.0, # How much we increase timeout each time
statuses: set[int] | None = None, # On which statuses we should retry
exceptions: set[type[Exception]] | None = None, # On which exceptions we should retry
methods: set[str] | None = None, # On which HTTP methods we should retry
retry_all_server_errors: bool = True,
evaluate_response_callback: EvaluateResponseCallbackType | None = None,
) -> None:
super().__init__(
attempts=attempts,
statuses=statuses,
exceptions=exceptions,
methods=methods,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)
Expand Down Expand Up @@ -82,6 +89,7 @@ def __init__(
attempts: int = 3, # How many times we should retry
statuses: Iterable[int] | None = None, # On which statuses we should retry
exceptions: Iterable[type[Exception]] | None = None, # On which exceptions we should retry
methods: Iterable[str] | None = None, # On which HTTP methods we should retry
min_timeout: float = 0.1, # Minimum possible timeout
max_timeout: float = 3.0, # Maximum possible timeout between tries
random_func: Callable[[], float] = random.random, # Random number generator
Expand All @@ -92,6 +100,7 @@ def __init__(
attempts=attempts,
statuses=statuses,
exceptions=exceptions,
methods=methods,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)
Expand All @@ -116,13 +125,15 @@ def __init__(
timeouts: list[float],
statuses: Iterable[int] | None = None, # On which statuses we should retry
exceptions: Iterable[type[Exception]] | None = None, # On which exceptions we should retry
methods: Iterable[str] | None = None, # On which HTTP methods we should retry
retry_all_server_errors: bool = True,
evaluate_response_callback: EvaluateResponseCallbackType | None = None,
) -> None:
super().__init__(
attempts=len(timeouts),
statuses=statuses,
exceptions=exceptions,
methods=methods,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)
Expand All @@ -144,6 +155,7 @@ def __init__(
multiplier: float = 1.0,
statuses: Iterable[int] | None = None,
exceptions: Iterable[type[Exception]] | None = None,
methods: Iterable[str] | None = None,
max_timeout: float = 3.0, # Maximum possible timeout between tries
retry_all_server_errors: bool = True,
evaluate_response_callback: EvaluateResponseCallbackType | None = None,
Expand All @@ -152,6 +164,7 @@ def __init__(
attempts=attempts,
statuses=statuses,
exceptions=exceptions,
methods=methods,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)
Expand Down Expand Up @@ -184,6 +197,7 @@ def __init__(
factor: float = 2.0, # How much we increase timeout each time
statuses: set[int] | None = None, # On which statuses we should retry
exceptions: set[type[Exception]] | None = None, # On which exceptions we should retry
methods: set[str] | None = None, # On which HTTP methods we should retry
random_interval_size: float = 2.0, # size of interval for random component
retry_all_server_errors: bool = True,
evaluate_response_callback: EvaluateResponseCallbackType | None = None,
Expand All @@ -195,6 +209,7 @@ def __init__(
factor=factor,
statuses=statuses,
exceptions=exceptions,
methods=methods,
retry_all_server_errors=retry_all_server_errors,
evaluate_response_callback=evaluate_response_callback,
)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,19 @@ async def test_list_retry_works_for_multiple_attempts(aiohttp_client: pytest_aio
await retry_client.close()


async def test_dont_retry_if_not_in_retry_methods(aiohttp_client: pytest_aiohttp.plugin.AiohttpClient) -> None:
retry_client, test_app = await get_retry_client_and_test_app_for_test(
aiohttp_client,
retry_options=ExponentialRetry(methods={"POST"}), # not "GET"
)

async with retry_client.get("/internal_error") as response:
assert response.status == 500
assert test_app.counter == 1

await retry_client.close()


async def test_implicit_client(aiohttp_client: pytest_aiohttp.plugin.AiohttpClient) -> None:
# check that if client not passed that it created implicitly
test_app = App()
Expand Down

0 comments on commit 1bace91

Please sign in to comment.