[PT FE] Retry hf hub model load up to 3 times if http error
mvafin committed Jul 19, 2024
1 parent 05ac1b5 commit 043d489
Showing 2 changed files with 19 additions and 1 deletion.
16 changes: 16 additions & 0 deletions tests/model_hub_tests/models_hub_common/utils.py
@@ -1,6 +1,7 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import functools
import itertools
import os
import shutil
@@ -145,3 +146,18 @@ def call_with_timer(timer_label: str, func, args):

def print_stat(s: str, value: float):
print(s.format(round_num(value)))


def retry(max_retries=3, exceptions=(Exception,)):
def retry_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except exceptions as e:
print(f"Attempt {attempt + 1} of {max_retries} failed: {e}")
if attempt == max_retries - 1:
raise e
return wrapper
return retry_decorator
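
The `retry` helper added above is a decorator factory: it calls the wrapped function up to `max_retries` times, prints each failed attempt, and re-raises the exception once the final attempt fails; exceptions not listed in the `exceptions` tuple propagate immediately. A minimal usage sketch under those assumptions (`flaky_download` and the simulated failure are hypothetical, for illustration only):

    import random

    from models_hub_common.utils import retry

    # Hypothetical example: the call fails randomly with ConnectionError;
    # the decorator retries it up to 3 times and re-raises after the last attempt.
    @retry(max_retries=3, exceptions=(ConnectionError,))
    def flaky_download(url):
        if random.random() < 0.5:
            raise ConnectionError(f"transient failure fetching {url}")
        return f"contents of {url}"

    print(flaky_download("https://example.com/model.bin"))
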
4 changes: 3 additions & 1 deletion tests/model_hub_tests/pytorch/test_hf_transformers.py
@@ -6,8 +6,9 @@
import pytest
import torch
from huggingface_hub import model_info
from huggingface_hub.utils import HfHubHTTPError
from models_hub_common.constants import hf_hub_cache_dir
from models_hub_common.utils import cleanup_dir
from models_hub_common.utils import cleanup_dir, retry
import transformers
from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer, AutoFeatureExtractor, AutoModelForTextToWaveform, \
CLIPFeatureExtractor, XCLIPVisionModel, T5Tokenizer, VisionEncoderDecoderModel, ViTImageProcessor, BlipProcessor, BlipForConditionalGeneration, \
@@ -101,6 +102,7 @@ def setup_class(self):
self.image = Image.open(requests.get(url, stream=True).raw)
self.cuda_available, self.gptq_postinit = None, None

@retry(3, exceptions=(HfHubHTTPError,))
def load_model(self, name, type):
name_suffix = ''
if name.find(':') != -1:
