Skip to content

Commit

Permalink
Enable test_update_module_state_from_ultralytics
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 19, 2021
1 parent 0ab1501 commit 518aa0f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
16 changes: 10 additions & 6 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import pytest

from pathlib import Path
import numpy as np
import torch
from torch import nn, Tensor

from yolort.utils import (
Expand All @@ -11,14 +11,18 @@
)


@pytest.mark.skip("Temporarily close the test here.")
def test_update_module_state_from_ultralytics():
yolov5s_r40_path = Path('yolov5s.pt')

if not yolov5s_r40_path.is_file():
yolov5s_r40_url = 'https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt'
torch.hub.download_url_to_file(yolov5s_r40_url, yolov5s_r40_path, hash_prefix='a67d2887')

model = update_module_state_from_ultralytics(
str(yolov5s_r40_path),
arch='yolov5s',
version='v4.0',
feature_fusion_type='PAN',
num_classes=80,
custom_path_or_model=None,
)
assert isinstance(model, nn.Module)

Expand All @@ -33,7 +37,7 @@ def test_read_image_to_tensor():


def test_get_image_from_url():
url = "https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg"
url = 'https://raw.githubusercontent.com/zhiqwang/yolov5-rt-stack/master/test/assets/zidane.jpg'
img = get_image_from_url(url)
assert isinstance(img, np.ndarray)
assert tuple(img.shape) == (720, 1280, 3)
18 changes: 7 additions & 11 deletions yolort/utils/update_module_state.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
from typing import Any, Optional
from typing import Any

from functools import reduce
import torch
from torch import nn

from yolort.models import yolo
Expand All @@ -17,12 +16,12 @@


def update_module_state_from_ultralytics(
custom_path_path: str,
path_to_yolov5: Optional[str] = None,
model_path: str,
arch: str = 'yolov5s',
feature_fusion_type: str = 'PAN',
num_classes: int = 80,
set_fp16: bool = True,
verbose: bool = False,
**kwargs: Any,
):
"""
Expand All @@ -31,9 +30,7 @@ def update_module_state_from_ultralytics(
wish to re-train.
Args:
custom_path_path (str): Path to your custom model.
path_to_yolov5 (Optional[str]): Path of the local yolov5 repo.
Default: None.
model_path (str): Path to your custom model.
arch (str): yolo architecture. Possible values are 'yolov5s', 'yolov5m' and 'yolov5l'.
Default: 'yolov5s'.
feature_fusion_type (str): the type of fature fusion. Possible values are PAN and TAN.
Expand All @@ -42,12 +39,11 @@ def update_module_state_from_ultralytics(
Default: 80.
set_fp16 (bool): allow selective conversion to fp16 or not.
Default: True.
verbose (bool): print all information to screen. Default: True.
"""

if path_to_yolov5 is not None:
model = torch.hub.load(path_to_yolov5, 'custom', path=custom_path_path, source='local')
else:
model = torch.hub.load('ultralytics/yolov5', 'custom', path=custom_path_path)
from yolort.v5 import load_model
model = load_model(model_path, autoshape=False, verbose=verbose)

key_arch = f'{arch}_{feature_fusion_type.lower()}_v4.0'
if key_arch not in ARCHITECTURE_MAPS:
Expand Down

0 comments on commit 518aa0f

Please sign in to comment.