Skip to content

Commit

Permalink
[PT FE] Retry hf hub model load up to 3 times if http error (openvino…
Browse files Browse the repository at this point in the history
…toolkit#25648)

### Details:
 - *item1*
 - *...*

### Tickets:
 - *CVS-143832*
  • Loading branch information
mvafin authored Jul 22, 2024
1 parent cdf342e commit 554e6fe
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
18 changes: 18 additions & 0 deletions tests/model_hub_tests/models_hub_common/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import functools
import itertools
import os
import shutil
Expand Down Expand Up @@ -145,3 +146,20 @@ def call_with_timer(timer_label: str, func, args):

def print_stat(s: str, value: float):
print(s.format(round_num(value)))


def retry(max_retries=3, exceptions=(Exception,), delay=None):
def retry_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except exceptions as e:
print(f"Attempt {attempt + 1} of {max_retries} failed: {e}")
if attempt < max_retries - 1 and delay is not None:
time.sleep(delay)
else:
raise e
return wrapper
return retry_decorator
4 changes: 3 additions & 1 deletion tests/model_hub_tests/pytorch/test_hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import pytest
import torch
from huggingface_hub import model_info
from huggingface_hub.utils import HfHubHTTPError
from models_hub_common.constants import hf_hub_cache_dir
from models_hub_common.utils import cleanup_dir
from models_hub_common.utils import cleanup_dir, retry
import transformers
from transformers import AutoConfig, AutoModel, AutoProcessor, AutoTokenizer, AutoFeatureExtractor, AutoModelForTextToWaveform, \
CLIPFeatureExtractor, XCLIPVisionModel, T5Tokenizer, VisionEncoderDecoderModel, ViTImageProcessor, BlipProcessor, BlipForConditionalGeneration, \
Expand Down Expand Up @@ -101,6 +102,7 @@ def setup_class(self):
self.image = Image.open(requests.get(url, stream=True).raw)
self.cuda_available, self.gptq_postinit = None, None

@retry(3, exceptions=(HfHubHTTPError,), delay=1)
def load_model(self, name, type):
name_suffix = ''
if name.find(':') != -1:
Expand Down

0 comments on commit 554e6fe

Please sign in to comment.