diff --git a/tests/model_hub_tests/models_hub_common/utils.py b/tests/model_hub_tests/models_hub_common/utils.py index d223f3bc984c6e..6dac33640162de 100644 --- a/tests/model_hub_tests/models_hub_common/utils.py +++ b/tests/model_hub_tests/models_hub_common/utils.py @@ -1,6 +1,7 @@ # Copyright (C) 2018-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import functools import itertools import os import shutil @@ -145,3 +146,20 @@ def call_with_timer(timer_label: str, func, args): def print_stat(s: str, value: float): print(s.format(round_num(value))) + + +def retry(max_retries=3, exceptions=(Exception,), delay=None): + def retry_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except exceptions as e: + print(f"Attempt {attempt + 1} of {max_retries} failed: {e}") + if attempt < max_retries - 1 and delay is not None: + time.sleep(delay) + else: + raise e + return wrapper + return retry_decorator diff --git a/tests/model_hub_tests/pytorch/test_hf_transformers.py b/tests/model_hub_tests/pytorch/test_hf_transformers.py index 32d3ac70947a91..5e3f19ad945399 100644 --- a/tests/model_hub_tests/pytorch/test_hf_transformers.py +++ b/tests/model_hub_tests/pytorch/test_hf_transformers.py @@ -6,8 +6,9 @@ import pytest import torch from huggingface_hub import model_info +from huggingface_hub.utils import HfHubHTTPError from models_hub_common.constants import hf_hub_cache_dir -from models_hub_common.utils import cleanup_dir +from models_hub_common.utils import cleanup_dir, retry import transformers from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer, AutoFeatureExtractor, AutoModelForTextToWaveform, \ CLIPFeatureExtractor, XCLIPVisionModel, T5Tokenizer, VisionEncoderDecoderModel, ViTImageProcessor, BlipProcessor, BlipForConditionalGeneration, \ @@ -101,6 +102,7 @@ def setup_class(self): self.image = Image.open(requests.get(url, stream=True).raw) self.cuda_available, self.gptq_postinit = None, None + @retry(3, exceptions=(HfHubHTTPError,), delay=1) def load_model(self, name, type): name_suffix = '' if name.find(':') != -1: