Skip to content

Commit

Permalink
Cleanup unittest and yolort.utils
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 20, 2021
1 parent 82f0d09 commit d07161d
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 101 deletions.
28 changes: 0 additions & 28 deletions test/test_hooks_utils.py

This file was deleted.

44 changes: 0 additions & 44 deletions test/test_image_utils.py

This file was deleted.

66 changes: 66 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import pytest

import numpy as np
import torch
from torch import nn, Tensor

from yolort.models import yolov5s
from yolort.v5 import letterbox, scale_coords
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.utils import (
FeatureExtractor,
update_module_state_from_ultralytics,
read_image_to_tensor,
get_image_from_url,
Expand Down Expand Up @@ -41,3 +47,63 @@ def test_get_image_from_url():
img = get_image_from_url(url)
assert isinstance(img, np.ndarray)
assert tuple(img.shape) == (720, 1280, 3)


def test_letterbox():
img = np.random.randint(0, 255, (720, 360, 3), dtype='uint8') # As a dummy image
out = letterbox(img, new_shape=(416, 416))[0]
assert tuple(out.shape) == (416, 224, 3)


def test_box_cxcywh_to_xyxy():
box_cxcywh = np.asarray([[50, 50, 100, 100],
[0, 0, 0, 0],
[20, 25, 20, 20],
[58, 65, 70, 60]], dtype=np.float)
exp_xyxy = np.asarray([[0, 0, 100, 100],
[0, 0, 0, 0],
[10, 15, 30, 35],
[23, 35, 93, 95]], dtype=np.float)

box_xyxy = box_cxcywh_to_xyxy(box_cxcywh)
assert exp_xyxy.shape == (4, 4)
assert exp_xyxy.dtype == box_xyxy.dtype
np.testing.assert_array_equal(exp_xyxy, box_xyxy)


def test_scale_coords():
box_tensor = torch.tensor([[0., 0., 100., 100.],
[0., 0., 0., 0.],
[10., 15., 30., 35.],
[20., 35., 90., 95.]], dtype=torch.float)
exp_coords = torch.tensor([[0., 0., 108.05, 111.25],
[0., 0., 0., 0.],
[7.9250, 16.6875, 30.1750, 38.9375],
[19.05, 38.9375, 96.9250, 105.6875]], dtype=torch.float)

box_coords_scaled = scale_coords((160, 128), box_tensor, (178, 136))
assert tuple(box_coords_scaled.shape) == (4, 4)
torch.testing.assert_close(box_coords_scaled, exp_coords)


@pytest.mark.parametrize('b, h, w', [(8, 640, 640), (4, 416, 320), (8, 320, 416)])
def test_feature_extractor(b, h, w):
c = 3
in_channels = [128, 256, 512]
strides = [8, 16, 32]
num_outputs = 85
expected_features = [(b, inc, h // s, w // s) for inc, s in zip(in_channels, strides)]
expected_head_outputs = [(b, c, h // s, w // s, num_outputs) for s in strides]

model = yolov5s()
model = model.train()
yolo_features = FeatureExtractor(model.model, return_layers=['backbone', 'head'])
images = torch.rand(b, c, h, w)
targets = torch.rand(61, 6)
intermediate_features = yolo_features(images, targets)
features = intermediate_features['backbone']
head_outputs = intermediate_features['head']
assert isinstance(features, list)
assert [f.shape for f in features] == expected_features
assert isinstance(head_outputs, list)
assert [h.shape for h in head_outputs] == expected_head_outputs
24 changes: 22 additions & 2 deletions yolort/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
from .flash_utils import get_callable_dict
from typing import Callable, Dict, Mapping, Sequence, Union

from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import update_module_state_from_ultralytics
from .hooks_utils import FeatureExtractor
from .hooks import FeatureExtractor

__all__ = [
'cv2_imshow', 'get_image_from_url', 'read_image_to_tensor',
'update_module_state_from_ultralytics', 'FeatureExtractor',
'get_callable_dict',
]


def get_callable_name(fn_or_class: Union[Callable, object]) -> str:
return getattr(fn_or_class, "__name__", fn_or_class.__class__.__name__).lower()


def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Mapping]:
if isinstance(fn, Mapping):
return fn
elif isinstance(fn, Sequence):
return {get_callable_name(f): f for f in fn}
elif callable(fn):
return {get_callable_name(fn): fn}
27 changes: 0 additions & 27 deletions yolort/utils/flash_utils.py

This file was deleted.

File renamed without changes.

0 comments on commit d07161d

Please sign in to comment.