Skip to content

Commit

Permalink
Only call status_code_from_exception if the Exception is an instance …
Browse files Browse the repository at this point in the history
…of class.base_class
  • Loading branch information
tremble committed Sep 12, 2021
1 parent 4787ba9 commit 1401ce1
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 16 deletions.
46 changes: 31 additions & 15 deletions plugins/module_utils/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __next__(self):
return return_value


def _retry_func(func, sleep_time_generator, retries, catch_extra_error_codes, found_f, status_code_from_except_f):
def _retry_func(func, sleep_time_generator, retries, catch_extra_error_codes, found_f, status_code_from_except_f, base_class):
counter = 0
for sleep_time in sleep_time_generator:
try:
Expand All @@ -67,6 +67,8 @@ def _retry_func(func, sleep_time_generator, retries, catch_extra_error_codes, fo
counter += 1
if counter == retries:
raise
if base_class and not isinstance(exc, base_class):
raise
status_code = status_code_from_except_f(exc)
if found_f(status_code, catch_extra_error_codes):
time.sleep(sleep_time)
Expand All @@ -79,8 +81,16 @@ class CloudRetry:
The base class to be used by other cloud providers to provide a backoff/retry decorator based on status codes.
"""

base_class = type(None)

@staticmethod
def status_code_from_exception(error):
"""
Returns the Error 'code' from an exception.
Args:
error: The Exception from which the error code is to be extracted.
error will be an instance of class.base_class.
"""
raise NotImplementedError()

@staticmethod
Expand All @@ -96,8 +106,8 @@ def _is_iterable():
return True
return _is_iterable() and response_code in catch_extra_error_codes

@staticmethod
def base_decorator(retries, found, status_code_from_exception, catch_extra_error_codes, sleep_time_generator):
@classmethod
def base_decorator(cls, retries, found, status_code_from_exception, catch_extra_error_codes, sleep_time_generator):
def retry_decorator(func):
@functools.wraps(func)
def _retry_wrapper(*args, **kwargs):
Expand All @@ -108,7 +118,8 @@ def _retry_wrapper(*args, **kwargs):
retries=retries,
catch_extra_error_codes=catch_extra_error_codes,
found_f=found,
status_code_from_except_f=status_code_from_exception
status_code_from_except_f=status_code_from_exception,
base_class=cls.base_class,
)
return _retry_wrapper
return retry_decorator
Expand All @@ -131,11 +142,13 @@ def exponential_backoff(cls, retries=10, delay=3, backoff=2, max_delay=60, catch
Callable: A generator that calls the decorated function using an exponential backoff.
"""
sleep_time_generator = BackoffIterator(delay=delay, backoff=backoff, max_delay=max_delay)
return CloudRetry.base_decorator(retries=retries,
found=cls.found,
status_code_from_exception=cls.status_code_from_exception,
catch_extra_error_codes=catch_extra_error_codes,
sleep_time_generator=sleep_time_generator)
return cls.base_decorator(
retries=retries,
found=cls.found,
status_code_from_exception=cls.status_code_from_exception,
catch_extra_error_codes=catch_extra_error_codes,
sleep_time_generator=sleep_time_generator,
)

@classmethod
def jittered_backoff(cls, retries=10, delay=3, backoff=2.0, max_delay=60, catch_extra_error_codes=None):
Expand All @@ -155,11 +168,13 @@ def jittered_backoff(cls, retries=10, delay=3, backoff=2.0, max_delay=60, catch_
Callable: A generator that calls the decorated function using using a jittered backoff strategy.
"""
sleep_time_generator = BackoffIterator(delay=delay, backoff=backoff, max_delay=max_delay, jitter=True)
return CloudRetry.base_decorator(retries=retries,
found=cls.found,
status_code_from_exception=cls.status_code_from_exception,
catch_extra_error_codes=catch_extra_error_codes,
sleep_time_generator=sleep_time_generator)
return cls.base_decorator(
retries=retries,
found=cls.found,
status_code_from_exception=cls.status_code_from_exception,
catch_extra_error_codes=catch_extra_error_codes,
sleep_time_generator=sleep_time_generator,
)

@classmethod
def backoff(cls, tries=10, delay=3, backoff=1.1, catch_extra_error_codes=None):
Expand All @@ -184,4 +199,5 @@ def backoff(cls, tries=10, delay=3, backoff=1.1, catch_extra_error_codes=None):
delay=delay,
backoff=backoff,
max_delay=None,
catch_extra_error_codes=catch_extra_error_codes)
catch_extra_error_codes=catch_extra_error_codes,
)
66 changes: 65 additions & 1 deletion tests/unit/module_utils/test_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from ansible_collections.amazon.aws.plugins.module_utils.cloud import CloudRetry, BackoffIterator
import pytest
import unittest
import random
from datetime import datetime
Expand Down Expand Up @@ -62,12 +61,14 @@ def __str__(self):
return "TestException with status: {0}".format(self.status)

class UnitTestsRetry(CloudRetry):
base_class = Exception

@staticmethod
def status_code_from_exception(error):
return getattr(error, "status") if hasattr(error, "status") else None

class CustomRetry(CloudRetry):
base_class = Exception

@staticmethod
def status_code_from_exception(error):
Expand All @@ -80,6 +81,28 @@ def found(response_code, catch_extra_error_codes=None):
else:
return response_code in CloudRetryUtils.custom_error_codes

class KeyRetry(CloudRetry):
base_class = KeyError

@staticmethod
def status_code_from_exception(error):
return True

@staticmethod
def found(response_code, catch_extra_error_codes=None):
return True

class KeyAndIndexRetry(CloudRetry):
base_class = (KeyError, IndexError)

@staticmethod
def status_code_from_exception(error):
return True

@staticmethod
def found(response_code, catch_extra_error_codes=None):
return True

# ========================================================
# Setup some initial data that we can use within our tests
# ========================================================
Expand Down Expand Up @@ -195,3 +218,44 @@ def _fail():
assert duration == 2
finally:
assert raised

def test_only_base_exception(self):
def _fail_index():
my_list = list()
return my_list[5]

def _fail_key():
my_dict = dict()
return my_dict['invalid_key']

def _fail_exception():
raise Exception('bang')

key_retry_decorator = CloudRetryUtils.KeyRetry.exponential_backoff(retries=2, delay=2, backoff=4, max_delay=100)
key_and_index_retry_decorator = CloudRetryUtils.KeyAndIndexRetry.exponential_backoff(retries=2, delay=2, backoff=4, max_delay=100)

expectations = [
[key_retry_decorator, _fail_exception, 0],
[key_retry_decorator, _fail_index, 0],
[key_retry_decorator, _fail_key, 2],
[key_and_index_retry_decorator, _fail_exception, 0],
[key_and_index_retry_decorator, _fail_index, 2],
[key_and_index_retry_decorator, _fail_key, 2],
]

for expection in expectations:
decorator = expection[0]
function = expection[1]
duration = expection[2]

start = datetime.now()
raised = False
try:
decorator(function)()
except Exception:
raised = True
_duration = (datetime.now() - start).seconds
# Index errors shouldn't be retried
assert duration == _duration
finally:
assert raised

0 comments on commit 1401ce1

Please sign in to comment.