Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT FE] Fix issue with http error when using torch.hub #19901

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _run(self, model_name, model_link, ie_device):
fw_outputs = self.infer_fw_model(fw_model, inputs)
print("Infer ov::Model")
ov_outputs = self.infer_ov_model(ov_model, inputs, ie_device)
print("Compare TensorFlow and OpenVINO results")
print("Compare framework and OpenVINO results")
self.compare_results(fw_outputs, ov_outputs)

def run(self, model_name, model_link, ie_device):
Expand Down
2 changes: 1 addition & 1 deletion tests/model_hub_tests/models_hub_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_models_list(file_name: str):
model_name, model_link = model_info.split(',')
elif len(model_info.split(',')) == 4:
model_name, model_link, mark, reason = model_info.split(',')
assert mark == "skip", "Incorrect failure mark for model info {}".format(model_info)
assert mark in ["skip", "xfail"], "Incorrect failure mark for model info {}".format(model_info)
models.append((model_name, model_link, mark, reason))

return models
Expand Down
21 changes: 13 additions & 8 deletions tests/model_hub_tests/torch_tests/test_torchvision_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import os
import pytest
import torch
import tempfile
import torchvision.transforms.functional as F
from models_hub_common.test_convert_model import TestConvertModel
from openvino import convert_model
from models_hub_common.test_convert_model import TestConvertModel
from models_hub_common.utils import get_models_list


def get_all_models() -> list:
m_list = torch.hub.list("pytorch/vision")
m_list = torch.hub.list("pytorch/vision", skip_validation=True)
m_list.remove("get_model_weights")
m_list.remove("get_weight")
return m_list
Expand All @@ -36,7 +38,8 @@ def get_video():


def prepare_frames_for_raft(name, frames1, frames2):
w = torch.hub.load("pytorch/vision", "get_model_weights", name=name).DEFAULT
w = torch.hub.load("pytorch/vision", "get_model_weights",
name=name, skip_validation=True).DEFAULT
img1_batch = torch.stack(frames1)
img2_batch = torch.stack(frames2)
img1_batch = F.resize(img1_batch, size=[520, 960], antialias=False)
Expand All @@ -50,13 +53,14 @@ def prepare_frames_for_raft(name, frames1, frames2):


class TestTorchHubConvertModel(TestConvertModel):
def setup_method(self):
def setup_class(self):
self.cache_dir = tempfile.TemporaryDirectory()
# set temp dir for torch cache
torch.hub.set_dir(str(self.cache_dir.name))

def load_model(self, model_name, model_link):
m = torch.hub.load("pytorch/vision", model_name, weights='DEFAULT')
m = torch.hub.load("pytorch/vision", model_name,
weights='DEFAULT', skip_validation=True)
m.eval()
if model_name == "s3d" or any([m in model_name for m in ["swin3d", "r3d_18", "mc3_18", "r2plus1d_18"]]):
self.example = (torch.randn([1, 3, 224, 224, 224]),)
Expand Down Expand Up @@ -109,7 +113,8 @@ def teardown_method(self):
def test_convert_model_precommit(self, model_name, ie_device):
self.run(model_name, None, ie_device)

@pytest.mark.parametrize("model_name", get_all_models())
@pytest.mark.parametrize("name",
[pytest.param(n, marks=pytest.mark.xfail) if m == "xfail" else n for n, _, m, r in get_models_list(os.path.join(os.path.dirname(__file__), "torchvision_models"))])
@pytest.mark.nightly
def test_convert_model_all_models(self, model_name, ie_device):
self.run(model_name, None, ie_device)
def test_convert_model_all_models(self, name, ie_device):
self.run(name, None, ie_device)
97 changes: 97 additions & 0 deletions tests/model_hub_tests/torch_tests/torchvision_models
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
alexnet,none
convnext_base,none
convnext_large,none
convnext_small,none
convnext_tiny,none
deeplabv3_mobilenet_v3_large,none
deeplabv3_resnet101,none
deeplabv3_resnet50,none
densenet121,none
densenet161,none
densenet169,none
densenet201,none
efficientnet_b0,none
efficientnet_b1,none
efficientnet_b2,none
efficientnet_b3,none
efficientnet_b4,none
efficientnet_b5,none
efficientnet_b6,none
efficientnet_b7,none
efficientnet_v2_l,none
efficientnet_v2_m,none
efficientnet_v2_s,none
fcn_resnet101,none
fcn_resnet50,none
googlenet,none
inception_v3,none
lraspp_mobilenet_v3_large,none
maxvit_t,none
mc3_18,none
mnasnet0_5,none
mnasnet0_75,none
mnasnet1_0,none
mnasnet1_3,none
mobilenet_v2,none
mobilenet_v3_large,none
mobilenet_v3_small,none
mvit_v1_b,none
mvit_v2_s,none
r2plus1d_18,none
r3d_18,none
raft_large,none
raft_small,none
regnet_x_16gf,none
regnet_x_1_6gf,none
regnet_x_32gf,none
regnet_x_3_2gf,none
regnet_x_400mf,none
regnet_x_800mf,none
regnet_x_8gf,none
regnet_y_128gf,none
regnet_y_16gf,none
regnet_y_1_6gf,none
regnet_y_32gf,none
regnet_y_3_2gf,none
regnet_y_400mf,none
regnet_y_800mf,none
regnet_y_8gf,none
resnet101,none
resnet152,none
resnet18,none
resnet34,none
resnet50,none
resnext101_32x8d,none
resnext101_64x4d,none
resnext50_32x4d,none
s3d,none
shufflenet_v2_x0_5,none
shufflenet_v2_x1_0,none
shufflenet_v2_x1_5,none
shufflenet_v2_x2_0,none
squeezenet1_0,none
squeezenet1_1,none
swin3d_b,none
swin3d_s,none
swin3d_t,none
swin_b,none
swin_s,none
swin_t,none
swin_v2_b,none
swin_v2_s,none
swin_v2_t,none
vgg11,none
vgg11_bn,none
vgg13,none
vgg13_bn,none
vgg16,none
vgg16_bn,none
vgg19,none
vgg19_bn,none
vit_b_16,none,xfail,Tracing fails
vit_b_32,none,xfail,Tracing fails
vit_h_14,none,xfail,Tracing fails
vit_l_16,none,xfail,Tracing fails
vit_l_32,none,xfail,Tracing fails
wide_resnet101_2,none
wide_resnet50_2,none