From 043d4897ac01b2ef184fc950bb99407a65d8e135 Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Fri, 19 Jul 2024 16:41:59 +0200 Subject: [PATCH] [PT FE] Retry hf hub model load up to 3 times if http error --- tests/model_hub_tests/models_hub_common/utils.py | 16 ++++++++++++++++ .../pytorch/test_hf_transformers.py | 4 +++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/model_hub_tests/models_hub_common/utils.py b/tests/model_hub_tests/models_hub_common/utils.py index d223f3bc984c6e..3df17a343dd05b 100644 --- a/tests/model_hub_tests/models_hub_common/utils.py +++ b/tests/model_hub_tests/models_hub_common/utils.py @@ -1,6 +1,7 @@ # Copyright (C) 2018-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import functools import itertools import os import shutil @@ -145,3 +146,18 @@ def call_with_timer(timer_label: str, func, args): def print_stat(s: str, value: float): print(s.format(round_num(value))) + + +def retry(max_retries=3, exceptions=(Exception,)): + def retry_decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except exceptions as e: + print(f"Attempt {attempt + 1} of {max_retries} failed: {e}") + if attempt == max_retries - 1: + raise e + return wrapper + return retry_decorator diff --git a/tests/model_hub_tests/pytorch/test_hf_transformers.py b/tests/model_hub_tests/pytorch/test_hf_transformers.py index 32d3ac70947a91..52532ccd6b2c74 100644 --- a/tests/model_hub_tests/pytorch/test_hf_transformers.py +++ b/tests/model_hub_tests/pytorch/test_hf_transformers.py @@ -6,8 +6,9 @@ import pytest import torch from huggingface_hub import model_info +from huggingface_hub.utils import HfHubHTTPError from models_hub_common.constants import hf_hub_cache_dir -from models_hub_common.utils import cleanup_dir +from models_hub_common.utils import cleanup_dir, retry import transformers from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer, AutoFeatureExtractor, AutoModelForTextToWaveform, \ CLIPFeatureExtractor, XCLIPVisionModel, T5Tokenizer, VisionEncoderDecoderModel, ViTImageProcessor, BlipProcessor, BlipForConditionalGeneration, \ @@ -101,6 +102,7 @@ def setup_class(self): self.image = Image.open(requests.get(url, stream=True).raw) self.cuda_available, self.gptq_postinit = None, None + @retry(3, exceptions=(HfHubHTTPError,)) def load_model(self, name, type): name_suffix = '' if name.find(':') != -1: