diff --git a/test/test_utils.py b/test/test_utils.py index 47669952..39c7daf6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,7 +1,7 @@ # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. -import pytest - +from pathlib import Path import numpy as np +import torch from torch import nn, Tensor from yolort.utils import ( @@ -11,14 +11,18 @@ ) -@pytest.mark.skip("Temporarily close the test here.") def test_update_module_state_from_ultralytics(): + yolov5s_r40_path = Path('yolov5s.pt') + + if not yolov5s_r40_path.is_file(): + yolov5s_r40_url = 'https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt' + torch.hub.download_url_to_file(yolov5s_r40_url, yolov5s_r40_path, hash_prefix='9ca9a642') + model = update_module_state_from_ultralytics( + str(yolov5s_r40_path), arch='yolov5s', - version='v4.0', feature_fusion_type='PAN', num_classes=80, - custom_path_or_model=None, ) assert isinstance(model, nn.Module) @@ -33,7 +37,7 @@ def test_read_image_to_tensor(): def test_get_image_from_url(): - url = "https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg" + url = 'https://raw.githubusercontent.com/zhiqwang/yolov5-rt-stack/master/test/assets/zidane.jpg' img = get_image_from_url(url) assert isinstance(img, np.ndarray) assert tuple(img.shape) == (720, 1280, 3) diff --git a/yolort/utils/update_module_state.py b/yolort/utils/update_module_state.py index 68d648b0..ffddc774 100644 --- a/yolort/utils/update_module_state.py +++ b/yolort/utils/update_module_state.py @@ -1,8 +1,7 @@ # Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. -from typing import Any, Optional +from typing import Any from functools import reduce -import torch from torch import nn from yolort.models import yolo @@ -17,12 +16,12 @@ def update_module_state_from_ultralytics( - custom_path_path: str, - path_to_yolov5: Optional[str] = None, + model_path: str, arch: str = 'yolov5s', feature_fusion_type: str = 'PAN', num_classes: int = 80, set_fp16: bool = True, + verbose: bool = False, **kwargs: Any, ): """ @@ -31,9 +30,7 @@ def update_module_state_from_ultralytics( wish to re-train. Args: - custom_path_path (str): Path to your custom model. - path_to_yolov5 (Optional[str]): Path of the local yolov5 repo. - Default: None. + model_path (str): Path to your custom model. arch (str): yolo architecture. Possible values are 'yolov5s', 'yolov5m' and 'yolov5l'. Default: 'yolov5s'. feature_fusion_type (str): the type of fature fusion. Possible values are PAN and TAN. @@ -42,12 +39,11 @@ def update_module_state_from_ultralytics( Default: 80. set_fp16 (bool): allow selective conversion to fp16 or not. Default: True. + verbose (bool): print all information to screen. Default: True. """ - if path_to_yolov5 is not None: - model = torch.hub.load(path_to_yolov5, 'custom', path=custom_path_path, source='local') - else: - model = torch.hub.load('ultralytics/yolov5', 'custom', path=custom_path_path) + from yolort.v5 import load_model + model = load_model(model_path, autoshape=False, verbose=verbose) key_arch = f'{arch}_{feature_fusion_type.lower()}_v4.0' if key_arch not in ARCHITECTURE_MAPS: