Skip to content

Commit

Permalink
Fixing unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 19, 2021
1 parent 69c4924 commit ea81728
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
9 changes: 5 additions & 4 deletions test/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

import torch

from yolort.utils.image_utils import box_cxcywh_to_xyxy, letterbox, scale_coords
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.ultralytics import letterbox, scale_coords


def test_letterbox():
Expand All @@ -16,11 +17,11 @@ def test_box_cxcywh_to_xyxy():
box_cxcywh = np.asarray([[50, 50, 100, 100],
[0, 0, 0, 0],
[20, 25, 20, 20],
[58, 65, 70, 60]], dtype=np.float32)
[58, 65, 70, 60]], dtype=np.float)
exp_xyxy = np.asarray([[0, 0, 100, 100],
[0, 0, 0, 0],
[10, 15, 30, 35],
[23, 35, 93, 95]], dtype=np.float32)
[23, 35, 93, 95]], dtype=np.float)

box_xyxy = box_cxcywh_to_xyxy(box_cxcywh)
assert exp_xyxy.shape == (4, 4)
Expand All @@ -38,6 +39,6 @@ def test_scale_coords():
[7.9250, 16.6875, 30.1750, 38.9375],
[19.05, 38.9375, 96.9250, 105.6875]], dtype=torch.float)

box_coords_scaled = scale_coords(box_tensor, (160, 128), (178, 136))
box_coords_scaled = scale_coords((160, 128), box_tensor, (178, 136))
assert tuple(box_coords_scaled.shape) == (4, 4)
torch.testing.assert_close(box_coords_scaled, exp_coords)
2 changes: 1 addition & 1 deletion yolort/ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .models.common import *
from .models.yolo import *
from .helper import load_yolov5_model
from .helper import add_yolov5_context
from .hubconf import yolov5s
6 changes: 3 additions & 3 deletions yolort/ultralytics/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@contextlib.contextmanager
def load_yolov5_model():
def add_yolov5_context():
"""
Temporarily add yolov5 folder to `sys.path`. Modified from:
https://github.com/fcakyon/yolov5-pip/blob/0d03de6/yolov5/utils/general.py#L739-L754
Expand All @@ -38,7 +38,6 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
else:
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().eval()) # without layer fuse


# Compatibility updates
for m in model.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
Expand All @@ -52,5 +51,6 @@ def attempt_load(weights, map_location=None, inplace=True, fuse=True):
print(f'Ensemble created with {weights}\n')
for k in ['names']:
setattr(model, k, getattr(model[-1], k))
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
# max stride
model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride
return model # return ensemble

0 comments on commit ea81728

Please sign in to comment.