Skip to content

Commit

Permalink
Fixing ImprotError
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Oct 16, 2021
1 parent 7ba19aa commit b96855e
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 164 deletions.
83 changes: 0 additions & 83 deletions test/test_models_utils.py

This file was deleted.

102 changes: 95 additions & 7 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,94 @@
import pytest
import torch
from torch import Tensor
from yolort.models import yolov5s
from yolort import models
from yolort.models import YOLOv5
from yolort.utils import (
FeatureExtractor,
get_image_from_url,
load_from_ultralytics,
read_image_to_tensor,
)
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.v5 import letterbox, scale_coords


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix, use_p6",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642", False),
("yolov5s", "r4.0", "v6.0", "c3b140f3", False),
],
)
def test_load_from_ultralytics(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
use_p6: bool,
):
checkpoint_path = f"{arch}_{version}_{hash_prefix}"
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)
model_info = load_from_ultralytics(checkpoint_path, version=version)
assert isinstance(model_info, dict)
assert model_info["num_classes"] == 80
assert model_info["size"] == arch.replace("yolov5", "")
assert model_info["use_p6"] == use_p6
assert len(model_info["strides"]) == 4 if use_p6 else 3


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[("yolov5s", "r4.0", "v4.0", "9ca9a642")],
)
def test_load_from_yolov5(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
):
img_path = "test/assets/bus.jpg"
checkpoint_path = f"{arch}_{version}_{hash_prefix}"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)

model_yolov5 = YOLOv5.load_from_yolov5(checkpoint_path, version=version)
model_yolov5.eval()
out_from_yolov5 = model_yolov5.predict(img_path)
assert isinstance(out_from_yolov5[0], dict)
assert isinstance(out_from_yolov5[0]["boxes"], Tensor)
assert isinstance(out_from_yolov5[0]["labels"], Tensor)
assert isinstance(out_from_yolov5[0]["scores"], Tensor)

model = models.__dict__[arch](pretrained=True, score_thresh=0.25)
model.eval()
out = model.predict(img_path)

torch.testing.assert_close(
out_from_yolov5[0]["scores"], out[0]["scores"], rtol=0, atol=0
)
torch.testing.assert_close(
out_from_yolov5[0]["labels"], out[0]["labels"], rtol=0, atol=0
)
torch.testing.assert_close(
out_from_yolov5[0]["boxes"], out[0]["boxes"], rtol=0, atol=0
)


def test_read_image_to_tensor():
N, H, W = 3, 720, 360
img = np.random.randint(0, 255, (H, W, N), dtype="uint8") # As a dummy image
Expand Down Expand Up @@ -76,21 +154,31 @@ def test_scale_coords():
torch.testing.assert_close(box_coords_scaled, exp_coords)


@pytest.mark.parametrize("b, h, w", [(8, 640, 640), (4, 416, 320), (8, 320, 416)])
def test_feature_extractor(b, h, w):
@pytest.mark.parametrize(
"batch_size, height, width",
[
(8, 640, 640),
(4, 416, 320),
(8, 320, 416),
]
)
@pytest.mark.parametrize("arch", ["yolov5n", "yolov5s"])
def test_feature_extractor(batch_size, height, width, arch):
c = 3
in_channels = [128, 256, 512]
strides = [8, 16, 32]
num_outputs = 85
expected_features = [
(b, inc, h // s, w // s) for inc, s in zip(in_channels, strides)
(batch_size, inc, height // s, width // s) for inc, s in zip(in_channels, strides)
]
expected_head_outputs = [
(batch_size, c, height // s, width // s, num_outputs) for s in strides
]
expected_head_outputs = [(b, c, h // s, w // s, num_outputs) for s in strides]

model = yolov5s()
model = models.__dict__[arch]()
model = model.train()
yolo_features = FeatureExtractor(model.model, return_layers=["backbone", "head"])
images = torch.rand(b, c, h, w)
images = torch.rand(batch_size, c, height, width)
targets = torch.rand(61, 6)
intermediate_features = yolo_features(images, targets)
features = intermediate_features["backbone"]
Expand Down
2 changes: 1 addition & 1 deletion yolort/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from torch import nn

from yolort.utils.activations import Hardswish, SiLU
from yolort.v5 import Conv
from yolort.v5.utils.activations import Hardswish, SiLU
from .yolo import YOLO
from .yolo_module import YOLOv5

Expand Down
65 changes: 0 additions & 65 deletions yolort/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,71 +5,6 @@
from torch import nn, Tensor
from torchvision.ops import box_convert, box_iou

from yolort.utils import ModuleStateUpdate
from yolort.v5 import load_yolov5_model, get_yolov5_size


def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
"""
Allows the user to load model state file from the checkpoint trained from
the ultralytics/yolov5.
Args:
checkpoint_path (str): Path of the YOLOv5 checkpoint model.
version (str): upstream version released by the ultralytics/yolov5, Possible
values are ["r3.1", "r4.0", "r6.0"]. Default: "r6.0".
"""

assert version in [
"r3.1",
"r4.0",
"r6.0",
], "Currently does not support this version."

checkpoint_yolov5 = load_yolov5_model(checkpoint_path)
num_classes = checkpoint_yolov5.yaml["nc"]
strides = checkpoint_yolov5.stride
anchor_grids = checkpoint_yolov5.yaml["anchors"]
depth_multiple = checkpoint_yolov5.yaml["depth_multiple"]
width_multiple = checkpoint_yolov5.yaml["width_multiple"]

use_p6 = False
if len(strides) == 4:
use_p6 = True

if use_p6:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}
else:
inner_block_maps = {"0": "9", "1": "10", "3": "13", "4": "14"}
layer_block_maps = {"0": "17", "1": "18", "2": "20", "3": "21", "4": "23"}

module_state_updater = ModuleStateUpdate(
arch=None,
depth_multiple=depth_multiple,
width_multiple=width_multiple,
version=version,
num_classes=num_classes,
inner_block_maps=inner_block_maps,
layer_block_maps=layer_block_maps,
use_p6=use_p6,
)
module_state_updater.updating(checkpoint_yolov5)
state_dict = module_state_updater.model.state_dict()

size = get_yolov5_size(depth_multiple, width_multiple)

return {
"num_classes": num_classes,
"depth_multiple": depth_multiple,
"width_multiple": width_multiple,
"strides": strides,
"anchor_grids": anchor_grids,
"use_p6": use_p6,
"size": size,
"state_dict": state_dict,
}


def _evaluate_iou(target, pred):
"""
Expand Down
2 changes: 1 addition & 1 deletion yolort/models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import nn, Tensor
from torchvision.models.utils import load_state_dict_from_url

from ._utils import load_from_ultralytics
from yolort.utils import load_from_ultralytics
from .anchor_utils import AnchorGenerator
from .backbone_utils import darknet_pan_backbone
from .box_head import YOLOHead, SetCriterion, PostProcess
Expand Down
3 changes: 2 additions & 1 deletion yolort/models/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from torchvision.io import read_image

from yolort.data import COCOEvaluator, contains_any_tensor
from yolort.utils import load_from_ultralytics
from . import yolo
from ._utils import _evaluate_iou, load_from_ultralytics
from ._utils import _evaluate_iou
from .transform import YOLOTransform

__all__ = ["YOLOv5"]
Expand Down
5 changes: 3 additions & 2 deletions yolort/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

from .hooks import FeatureExtractor
from .image_utils import cv2_imshow, get_image_from_url, read_image_to_tensor
from .update_module_state import ModuleStateUpdate
from .update_module_state import load_from_ultralytics


__all__ = [
"FeatureExtractor",
"ModuleStateUpdate",
"cv2_imshow",
"get_image_from_url",
"get_callable_dict",
"load_from_ultralytics",
"read_image_to_tensor",
]

Expand Down
Loading

0 comments on commit b96855e

Please sign in to comment.