diff --git a/tests/model_hub_tests/models_hub_common/utils.py b/tests/model_hub_tests/models_hub_common/utils.py index 3df17a343dd05b..6dac33640162de 100644 --- a/tests/model_hub_tests/models_hub_common/utils.py +++ b/tests/model_hub_tests/models_hub_common/utils.py @@ -148,7 +148,7 @@ def print_stat(s: str, value: float): print(s.format(round_num(value))) -def retry(max_retries=3, exceptions=(Exception,)): +def retry(max_retries=3, exceptions=(Exception,), delay=None): def retry_decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): @@ -157,7 +157,9 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) except exceptions as e: print(f"Attempt {attempt + 1} of {max_retries} failed: {e}") - if attempt == max_retries - 1: + if attempt < max_retries - 1 and delay is not None: + time.sleep(delay) + else: raise e return wrapper return retry_decorator diff --git a/tests/model_hub_tests/pytorch/test_hf_transformers.py b/tests/model_hub_tests/pytorch/test_hf_transformers.py index 52532ccd6b2c74..5e3f19ad945399 100644 --- a/tests/model_hub_tests/pytorch/test_hf_transformers.py +++ b/tests/model_hub_tests/pytorch/test_hf_transformers.py @@ -102,7 +102,7 @@ def setup_class(self): self.image = Image.open(requests.get(url, stream=True).raw) self.cuda_available, self.gptq_postinit = None, None - @retry(3, exceptions=(HfHubHTTPError,)) + @retry(3, exceptions=(HfHubHTTPError,), delay=1) def load_model(self, name, type): name_suffix = '' if name.find(':') != -1: