Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loading model weight from ultralytics #167

Merged
merged 30 commits into from
Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
696524c
Updating yolov5 module updating to master
zhiqwang Sep 17, 2021
6406ce1
Support loading local yolov5 repo
zhiqwang Sep 17, 2021
89be2f0
Add ultralytics/yolov5 to yolort (#168)
zhiqwang Sep 17, 2021
ad7033d
Move ultralytics/yolov5 to yolort/ultralytics
zhiqwang Sep 18, 2021
f67fe8d
Add docstrings for exporting friendly activation modules
zhiqwang Sep 18, 2021
2f30492
Fixing loading yolov5 from ultralytics
zhiqwang Sep 18, 2021
46178b6
Fixing loading yolov5 from ultralytics
zhiqwang Sep 18, 2021
dbb6283
Fixing loading yolov5 from ultralytics
zhiqwang Sep 18, 2021
cb43764
Fixing unit-test
zhiqwang Sep 18, 2021
bf8cfeb
Cleanup codes
zhiqwang Sep 18, 2021
df3f033
Load yolov5 models in ultralytics
zhiqwang Sep 19, 2021
7668972
Move yolo_vanilla back to yolo.py
zhiqwang Sep 19, 2021
4288f6e
Cleanup utils
zhiqwang Sep 19, 2021
69c4924
Add *.pt to .gitignore
zhiqwang Sep 19, 2021
ea81728
Fixing unittest
zhiqwang Sep 19, 2021
a923118
Adopt torch.rand instead of torch.randn
zhiqwang Sep 19, 2021
0d16620
Add load_model method and fixing loading from ultralytics
zhiqwang Sep 19, 2021
fda207a
Remove hubconf.py
zhiqwang Sep 19, 2021
2f1bbbf
Resolve importing modules
zhiqwang Sep 19, 2021
5fb33a7
Move ultralytics to v5
zhiqwang Sep 19, 2021
0ab1501
Remove train_vanilla.py
zhiqwang Sep 19, 2021
f1276a0
Enable `test_update_module_state_from_ultralytics`
zhiqwang Sep 19, 2021
9a652d2
Remove loggers temporarily
zhiqwang Sep 19, 2021
82f0d09
Reduce column length
zhiqwang Sep 20, 2021
d07161d
Cleanup unittest and yolort.utils
zhiqwang Sep 20, 2021
3698657
Rename load_model to load_yolov5_model
zhiqwang Sep 21, 2021
769c142
Add unittest for load_yolov5_model
zhiqwang Sep 21, 2021
251226f
Rename model_path to checkpoint_path
zhiqwang Sep 21, 2021
0d0b5f1
Add load_from_yolov5 in YOLOModule
zhiqwang Sep 21, 2021
348f053
Add unittest for load_from_yolov5
zhiqwang Sep 21, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ logs
lightning_logs
*.ipynb
runs
yolov5s.pt
*.pt
yolort/version.py
.idea
*.ttf
wandb

# macOS dir files
.DS_Store

Expand Down Expand Up @@ -39,6 +42,7 @@ htmlcov/
.coverage
.coverage.*
.cache
test-output.xml
nosetests.xml
coverage.xml
*.cover
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# opencv-python>=4.1.2
matplotlib>=3.2.2
numpy>=1.18.5
pillow
Pillow>=8.0.0
scipy>=1.4.1
tqdm>=4.41.0

Expand All @@ -22,6 +22,7 @@ onnx>=1.8.0
# plotting ------------------------------------
ipython
tabulate
pandas

# Lightning -----------------------------------
pytorch_lightning>=1.3.1
Expand Down
4 changes: 2 additions & 2 deletions test/test_data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
def test_contains_any_tensor():
dummy_numpy = np.random.randn(3, 6)
assert not contains_any_tensor(dummy_numpy)
dummy_tensor = torch.randn(3, 6)
dummy_tensor = torch.rand(3, 6)
assert contains_any_tensor(dummy_tensor)
dummy_tensors = [torch.randn(3, 6), torch.randn(9, 5)]
dummy_tensors = [torch.rand(3, 6), torch.rand(9, 5)]
assert contains_any_tensor(dummy_tensors)


Expand Down
28 changes: 0 additions & 28 deletions test/test_hooks_utils.py

This file was deleted.

43 changes: 0 additions & 43 deletions test/test_image_utils.py

This file was deleted.

43 changes: 37 additions & 6 deletions test/test_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import io
from pathlib import Path
import contextlib
import warnings
import pytest

import torch
from torch import Tensor

Expand Down Expand Up @@ -235,9 +237,9 @@ def test_postprocessors(self):

assert len(out) == N
assert isinstance(out[0], dict)
assert isinstance(out[0]["boxes"], Tensor)
assert isinstance(out[0]["labels"], Tensor)
assert isinstance(out[0]["scores"], Tensor)
assert isinstance(out[0]['boxes'], Tensor)
assert isinstance(out[0]['labels'], Tensor)
assert isinstance(out[0]['scores'], Tensor)
_check_jit_scriptable(model, (head_outputs, anchors_tuple))

def test_criterion(self):
Expand Down Expand Up @@ -272,6 +274,35 @@ def test_torchscript(arch):
out = model(x)
out_script = scripted_model(x)

torch.testing.assert_close(out[0]["scores"], out_script[1][0]["scores"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["labels"], out_script[1][0]["labels"], rtol=0, atol=0)
torch.testing.assert_close(out[0]["boxes"], out_script[1][0]["boxes"], rtol=0, atol=0)
torch.testing.assert_close(out[0]['scores'], out_script[1][0]['scores'], rtol=0, atol=0)
torch.testing.assert_close(out[0]['labels'], out_script[1][0]['labels'], rtol=0, atol=0)
torch.testing.assert_close(out[0]['boxes'], out_script[1][0]['boxes'], rtol=0, atol=0)


@pytest.mark.parametrize('arch, version, hash_prefix', [
('yolov5s', 'v4.0', '9ca9a642')
])
def test_load_from_yolov5(arch, version, hash_prefix):
img_path = 'test/assets/bus.jpg'
yolov5s_r40_path = Path(f'{arch}.pt')

if not yolov5s_r40_path.is_file():
yolov5s_r40_url = f'https://github.com/ultralytics/yolov5/releases/download/{version}/{arch}.pt'
torch.hub.download_url_to_file(yolov5s_r40_url, yolov5s_r40_path, hash_prefix=hash_prefix)

model_load_from_yolov5 = models.__dict__[arch](score_thresh=0.25)
model_load_from_yolov5.load_from_yolov5(yolov5s_r40_path)
model_load_from_yolov5.eval()
out_from_yolov5 = model_load_from_yolov5.predict(img_path)
assert isinstance(out_from_yolov5[0], dict)
assert isinstance(out_from_yolov5[0]['boxes'], Tensor)
assert isinstance(out_from_yolov5[0]['labels'], Tensor)
assert isinstance(out_from_yolov5[0]['scores'], Tensor)

model = models.__dict__[arch](pretrained=True, score_thresh=0.25)
model.eval()
out = model.predict(img_path)

torch.testing.assert_close(out_from_yolov5[0]['scores'], out[0]['scores'], rtol=0, atol=0)
torch.testing.assert_close(out_from_yolov5[0]['labels'], out[0]['labels'], rtol=0, atol=0)
torch.testing.assert_close(out_from_yolov5[0]['boxes'], out[0]['boxes'], rtol=0, atol=0)
2 changes: 1 addition & 1 deletion test/test_models_anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def get_features(self, images):
return features

def test_anchor_generator(self):
images = torch.randn(2, 3, 10, 10)
images = torch.rand(2, 3, 10, 10)
features = self.get_features(images)
model = AnchorGenerator(self.strides, self.anchor_grids)
model.eval()
Expand Down
4 changes: 2 additions & 2 deletions test/test_models_common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import pytest
import torch
from yolort.models.common import focus_transform, space_to_depth
from yolort.v5 import focus_transform, space_to_depth


@pytest.mark.parametrize('n, b, h, w', [(1, 3, 480, 640), (4, 3, 416, 320), (4, 3, 320, 416)])
def test_space_to_depth(n, b, h, w):
tensor_input = torch.randn((n, b, h, w))
tensor_input = torch.rand((n, b, h, w))
out1 = focus_transform(tensor_input)
out2 = space_to_depth(tensor_input)
torch.testing.assert_close(out2, out1)
78 changes: 74 additions & 4 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,34 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import pytest

import numpy as np
import torch
from torch import nn, Tensor

from yolort.models import yolov5s
from yolort.v5 import letterbox, scale_coords
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.utils import (
FeatureExtractor,
update_module_state_from_ultralytics,
read_image_to_tensor,
get_image_from_url,
)


@pytest.mark.skip("Temporarily close the test here.")
def test_update_module_state_from_ultralytics():
yolov5s_r40_path = Path('yolov5s.pt')

if not yolov5s_r40_path.is_file():
yolov5s_r40_url = 'https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt'
torch.hub.download_url_to_file(yolov5s_r40_url, yolov5s_r40_path, hash_prefix='9ca9a642')

model = update_module_state_from_ultralytics(
str(yolov5s_r40_path),
arch='yolov5s',
version='v4.0',
feature_fusion_type='PAN',
num_classes=80,
custom_path_or_model=None,
)
assert isinstance(model, nn.Module)

Expand All @@ -33,7 +43,67 @@ def test_read_image_to_tensor():


def test_get_image_from_url():
url = "https://github.com/ultralytics/yolov5/raw/master/data/images/zidane.jpg"
url = 'https://raw.githubusercontent.com/zhiqwang/yolov5-rt-stack/master/test/assets/zidane.jpg'
img = get_image_from_url(url)
assert isinstance(img, np.ndarray)
assert tuple(img.shape) == (720, 1280, 3)


def test_letterbox():
img = np.random.randint(0, 255, (720, 360, 3), dtype='uint8') # As a dummy image
out = letterbox(img, new_shape=(416, 416))[0]
assert tuple(out.shape) == (416, 224, 3)


def test_box_cxcywh_to_xyxy():
box_cxcywh = np.asarray([[50, 50, 100, 100],
[0, 0, 0, 0],
[20, 25, 20, 20],
[58, 65, 70, 60]], dtype=np.float)
exp_xyxy = np.asarray([[0, 0, 100, 100],
[0, 0, 0, 0],
[10, 15, 30, 35],
[23, 35, 93, 95]], dtype=np.float)

box_xyxy = box_cxcywh_to_xyxy(box_cxcywh)
assert exp_xyxy.shape == (4, 4)
assert exp_xyxy.dtype == box_xyxy.dtype
np.testing.assert_array_equal(exp_xyxy, box_xyxy)


def test_scale_coords():
box_tensor = torch.tensor([[0., 0., 100., 100.],
[0., 0., 0., 0.],
[10., 15., 30., 35.],
[20., 35., 90., 95.]], dtype=torch.float)
exp_coords = torch.tensor([[0., 0., 108.05, 111.25],
[0., 0., 0., 0.],
[7.9250, 16.6875, 30.1750, 38.9375],
[19.05, 38.9375, 96.9250, 105.6875]], dtype=torch.float)

box_coords_scaled = scale_coords((160, 128), box_tensor, (178, 136))
assert tuple(box_coords_scaled.shape) == (4, 4)
torch.testing.assert_close(box_coords_scaled, exp_coords)


@pytest.mark.parametrize('b, h, w', [(8, 640, 640), (4, 416, 320), (8, 320, 416)])
def test_feature_extractor(b, h, w):
c = 3
in_channels = [128, 256, 512]
strides = [8, 16, 32]
num_outputs = 85
expected_features = [(b, inc, h // s, w // s) for inc, s in zip(in_channels, strides)]
expected_head_outputs = [(b, c, h // s, w // s, num_outputs) for s in strides]

model = yolov5s()
model = model.train()
yolo_features = FeatureExtractor(model.model, return_layers=['backbone', 'head'])
images = torch.rand(b, c, h, w)
targets = torch.rand(61, 6)
intermediate_features = yolo_features(images, targets)
features = intermediate_features['backbone']
head_outputs = intermediate_features['head']
assert isinstance(features, list)
assert [f.shape for f in features] == expected_features
assert isinstance(head_outputs, list)
assert [h.shape for h in head_outputs] == expected_head_outputs
24 changes: 24 additions & 0 deletions test/test_v5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import torch
from torch import Tensor

from yolort.v5 import load_yolov5_model


def test_load_yolov5_model():
img_path = 'test/assets/zidane.jpg'

yolov5s_r40_path = Path('yolov5s.pt')

if not yolov5s_r40_path.is_file():
yolov5s_r40_url = 'https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt'
torch.hub.download_url_to_file(yolov5s_r40_url, yolov5s_r40_path, hash_prefix='9ca9a642')

model = load_yolov5_model(str(yolov5s_r40_path), autoshape=True, verbose=False)
results = model(img_path)

assert isinstance(results.pred, list)
assert len(results.pred) == 1
assert isinstance(results.pred[0], Tensor)
assert results.pred[0].shape == (3, 6)
1 change: 1 addition & 0 deletions yolort/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from yolort import models
from yolort import data
from yolort import utils
from yolort import v5

try:
from .version import __version__ # noqa: F401
Expand Down
9 changes: 4 additions & 5 deletions yolort/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import Any
from torch import nn

from .common import Conv
from .yolo_module import YOLOModule

from ..utils.activations import Hardswish, SiLU
from yolort.utils.activations import Hardswish, SiLU
from yolort.v5 import Conv

from typing import Any
from .yolo_module import YOLOModule


def yolov5s(upstream_version: str = 'r4.0', export_friendly: bool = False, **kwargs: Any):
Expand Down
2 changes: 1 addition & 1 deletion yolort/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn, Tensor
from torchvision.ops import box_convert, box_iou

from typing import Tuple, List
from typing import Tuple


def _evaluate_iou(target, pred):
Expand Down
Loading