Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix num_anchors in load_from_ultralytics #285

Merged
merged 5 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 26 additions & 63 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,22 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import cv2
# Copyright (c) 2021, yolort team. All rights reserved.

import numpy as np
import pytest
import torch
from torch import Tensor
from torchvision.io import read_image
from yolort import models
from yolort.models import YOLO
from yolort.utils import (
get_image_from_url,
load_from_ultralytics,
read_image_to_tensor,
FeatureExtractor,
Visualizer,
)
from yolort.models import YOLOv5
from yolort.utils import get_image_from_url, load_from_ultralytics, read_image_to_tensor
from yolort.utils.image_utils import box_cxcywh_to_xyxy
from yolort.v5 import (
letterbox,
load_yolov5_model,
scale_coords,
non_max_suppression,
attempt_download,
)
from yolort.v5 import letterbox, scale_coords, attempt_download


@pytest.mark.parametrize("arch", ["yolov5n"])
def test_visualizer(arch):

from yolort.utils import Visualizer

model = models.__dict__[arch](pretrained=True, size=(320, 320), score_thresh=0.45)
model = model.eval()
img_path = "test/assets/zidane.jpg"
Expand Down Expand Up @@ -65,47 +56,29 @@ def test_load_from_ultralytics(
assert len(model_info["strides"]) == 4 if use_p6 else 3


@pytest.mark.skip(reason="Due to #235")
@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[("yolov5s-VOC", "r4.0", "v5.0", "23818cff")],
)
def test_load_from_ultralytics_voc(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
):
def test_load_from_ultralytics_voc(arch, version, upstream_version, hash_prefix):
img_path = "test/assets/bus.jpg"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

# Preprocess
img_raw = cv2.imread(img_path)
img = letterbox(img_raw, new_shape=(320, 320))[0]
img = read_image_to_tensor(img)

conf = 0.25
iou = 0.45

# Define YOLOv5 model
model_yolov5 = load_yolov5_model(checkpoint_path)
with torch.no_grad():
outs = model_yolov5(img[None])[0]
outs = non_max_suppression(outs, conf, iou, agnostic=True)
out_yolov5 = outs[0]

# Define yolort model
model_yolort = YOLO.load_from_yolov5(checkpoint_path, score_thresh=conf, version=version)
model_yolort.eval()
model = YOLOv5.load_from_yolov5(checkpoint_path, score_thresh=0.25, version="r4.0")

model = model.eval()
with torch.no_grad():
out_yolort = model_yolort(img[None])
predictions = model.predict(img_path)

torch.testing.assert_allclose(out_yolort[0]["boxes"], out_yolov5[:, :4])
torch.testing.assert_allclose(out_yolort[0]["scores"], out_yolov5[:, 4])
torch.testing.assert_allclose(out_yolort[0]["labels"], out_yolov5[:, 5].to(dtype=torch.int64))
assert isinstance(predictions[0], dict)
assert isinstance(predictions[0]["boxes"], Tensor)
assert isinstance(predictions[0]["labels"], Tensor)
assert isinstance(predictions[0]["scores"], Tensor)
assert len(predictions[0]["labels"]) == 4


def test_read_image_to_tensor():
Expand Down Expand Up @@ -171,34 +144,24 @@ def test_scale_coords():
torch.testing.assert_close(box_coords_scaled, exp_coords)


@pytest.mark.parametrize(
"batch_size, height, width",
[
(8, 640, 640),
(4, 416, 320),
(8, 320, 416),
],
)
@pytest.mark.parametrize(
"arch, width_multiple",
[
("yolov5n", 0.25),
("yolov5s", 0.5),
],
)
@pytest.mark.parametrize("batch_size, height, width", [(8, 640, 640), (4, 416, 320), (8, 320, 416)])
@pytest.mark.parametrize("arch, width_multiple", [("yolov5n", 0.25), ("yolov5s", 0.5)])
def test_feature_extractor(batch_size, height, width, arch, width_multiple):
c = 3

from yolort.utils import FeatureExtractor

channel = 3
grow_widths = [256, 512, 1024]
in_channels = [int(gw * width_multiple) for gw in grow_widths]
strides = [8, 16, 32]
num_outputs = 85
expected_features = [(batch_size, inc, height // s, width // s) for inc, s in zip(in_channels, strides)]
expected_head_outputs = [(batch_size, c, height // s, width // s, num_outputs) for s in strides]
expected_head_outputs = [(batch_size, channel, height // s, width // s, num_outputs) for s in strides]

model = models.__dict__[arch]()
model = model.train()
yolo_features = FeatureExtractor(model.model, return_layers=["backbone", "head"])
images = torch.rand(batch_size, c, height, width)
images = torch.rand(batch_size, channel, height, width)
targets = torch.rand(61, 6)
intermediate_features = yolo_features(images, targets)
features = intermediate_features["backbone"]
Expand Down
5 changes: 3 additions & 2 deletions yolort/utils/update_module_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2020, yolort team. All rights reserved.

from functools import reduce
from typing import List, Dict, Optional
from typing import Dict, List, Optional

import torch
from torch import nn
Expand Down Expand Up @@ -56,9 +56,10 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
# YOLOv5 will change the anchors setting when using the auto-anchor mechanism. So we
# use the following formula to compute the anchor_grids instead of attaching it via
# checkpoint_yolov5.yaml["anchors"]
num_anchors = checkpoint_yolov5.model[-1].anchors.shape[1]
anchor_grids = (
(checkpoint_yolov5.model[-1].anchors * checkpoint_yolov5.model[-1].stride.view(-1, 1, 1))
.reshape(1, -1, 6)
.reshape(1, -1, 2 * num_anchors)
.tolist()[0]
)

Expand Down