From 2ccfb6628c9197dce32d779fdf0284a90b138b40 Mon Sep 17 00:00:00 2001
From: Zhiqiang Wang <zhiqwang@foxmail.com>
Date: Thu, 10 Mar 2022 14:59:07 +0800
Subject: [PATCH] Fix `num_anchors` in load_from_ultralytics (#285)

* Enable test_load_from_ultralytics_voc

* The nms type is non-agnostic actually

* Fix num_anchors in load_from_ultralytics

* Apply pre-commit

* Update copyright
---
 test/test_utils.py                  | 89 +++++++++--------------------
 yolort/utils/update_module_state.py |  5 +-
 2 files changed, 29 insertions(+), 65 deletions(-)

diff --git a/test/test_utils.py b/test/test_utils.py
index d1d57542..199bcd4d 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,31 +1,22 @@
-# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
-import cv2
+# Copyright (c) 2021, yolort team. All rights reserved.
+
 import numpy as np
 import pytest
 import torch
 from torch import Tensor
 from torchvision.io import read_image
 from yolort import models
-from yolort.models import YOLO
-from yolort.utils import (
-    get_image_from_url,
-    load_from_ultralytics,
-    read_image_to_tensor,
-    FeatureExtractor,
-    Visualizer,
-)
+from yolort.models import YOLOv5
+from yolort.utils import get_image_from_url, load_from_ultralytics, read_image_to_tensor
 from yolort.utils.image_utils import box_cxcywh_to_xyxy
-from yolort.v5 import (
-    letterbox,
-    load_yolov5_model,
-    scale_coords,
-    non_max_suppression,
-    attempt_download,
-)
+from yolort.v5 import letterbox, scale_coords, attempt_download
 
 
 @pytest.mark.parametrize("arch", ["yolov5n"])
 def test_visualizer(arch):
+
+    from yolort.utils import Visualizer
+
     model = models.__dict__[arch](pretrained=True, size=(320, 320), score_thresh=0.45)
     model = model.eval()
     img_path = "test/assets/zidane.jpg"
@@ -65,47 +56,29 @@ def test_load_from_ultralytics(
     assert len(model_info["strides"]) == 4 if use_p6 else 3
 
 
-@pytest.mark.skip(reason="Due to #235")
 @pytest.mark.parametrize(
     "arch, version, upstream_version, hash_prefix",
     [("yolov5s-VOC", "r4.0", "v5.0", "23818cff")],
 )
-def test_load_from_ultralytics_voc(
-    arch: str,
-    version: str,
-    upstream_version: str,
-    hash_prefix: str,
-):
+def test_load_from_ultralytics_voc(arch, version, upstream_version, hash_prefix):
     img_path = "test/assets/bus.jpg"
 
     base_url = "https://github.com/ultralytics/yolov5/releases/download/"
     model_url = f"{base_url}/{upstream_version}/{arch}.pt"
     checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)
 
-    # Preprocess
-    img_raw = cv2.imread(img_path)
-    img = letterbox(img_raw, new_shape=(320, 320))[0]
-    img = read_image_to_tensor(img)
-
-    conf = 0.25
-    iou = 0.45
-
-    # Define YOLOv5 model
-    model_yolov5 = load_yolov5_model(checkpoint_path)
-    with torch.no_grad():
-        outs = model_yolov5(img[None])[0]
-        outs = non_max_suppression(outs, conf, iou, agnostic=True)
-        out_yolov5 = outs[0]
-
     # Define yolort model
-    model_yolort = YOLO.load_from_yolov5(checkpoint_path, score_thresh=conf, version=version)
-    model_yolort.eval()
+    model = YOLOv5.load_from_yolov5(checkpoint_path, score_thresh=0.25, version="r4.0")
+
+    model = model.eval()
     with torch.no_grad():
-        out_yolort = model_yolort(img[None])
+        predictions = model.predict(img_path)
 
-    torch.testing.assert_allclose(out_yolort[0]["boxes"], out_yolov5[:, :4])
-    torch.testing.assert_allclose(out_yolort[0]["scores"], out_yolov5[:, 4])
-    torch.testing.assert_allclose(out_yolort[0]["labels"], out_yolov5[:, 5].to(dtype=torch.int64))
+    assert isinstance(predictions[0], dict)
+    assert isinstance(predictions[0]["boxes"], Tensor)
+    assert isinstance(predictions[0]["labels"], Tensor)
+    assert isinstance(predictions[0]["scores"], Tensor)
+    assert len(predictions[0]["labels"]) == 4
 
 
 def test_read_image_to_tensor():
@@ -171,34 +144,24 @@ def test_scale_coords():
     torch.testing.assert_close(box_coords_scaled, exp_coords)
 
 
-@pytest.mark.parametrize(
-    "batch_size, height, width",
-    [
-        (8, 640, 640),
-        (4, 416, 320),
-        (8, 320, 416),
-    ],
-)
-@pytest.mark.parametrize(
-    "arch, width_multiple",
-    [
-        ("yolov5n", 0.25),
-        ("yolov5s", 0.5),
-    ],
-)
+@pytest.mark.parametrize("batch_size, height, width", [(8, 640, 640), (4, 416, 320), (8, 320, 416)])
+@pytest.mark.parametrize("arch, width_multiple", [("yolov5n", 0.25), ("yolov5s", 0.5)])
 def test_feature_extractor(batch_size, height, width, arch, width_multiple):
-    c = 3
+
+    from yolort.utils import FeatureExtractor
+
+    channel = 3
     grow_widths = [256, 512, 1024]
     in_channels = [int(gw * width_multiple) for gw in grow_widths]
     strides = [8, 16, 32]
     num_outputs = 85
     expected_features = [(batch_size, inc, height // s, width // s) for inc, s in zip(in_channels, strides)]
-    expected_head_outputs = [(batch_size, c, height // s, width // s, num_outputs) for s in strides]
+    expected_head_outputs = [(batch_size, channel, height // s, width // s, num_outputs) for s in strides]
 
     model = models.__dict__[arch]()
     model = model.train()
     yolo_features = FeatureExtractor(model.model, return_layers=["backbone", "head"])
-    images = torch.rand(batch_size, c, height, width)
+    images = torch.rand(batch_size, channel, height, width)
     targets = torch.rand(61, 6)
     intermediate_features = yolo_features(images, targets)
     features = intermediate_features["backbone"]
diff --git a/yolort/utils/update_module_state.py b/yolort/utils/update_module_state.py
index 1d346535..30010f09 100644
--- a/yolort/utils/update_module_state.py
+++ b/yolort/utils/update_module_state.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2020, yolort team. All rights reserved.
 
 from functools import reduce
-from typing import List, Dict, Optional
+from typing import Dict, List, Optional
 
 import torch
 from torch import nn
@@ -56,9 +56,10 @@ def load_from_ultralytics(checkpoint_path: str, version: str = "r6.0"):
     # YOLOv5 will change the anchors setting when using the auto-anchor mechanism. So we
     # use the following formula to compute the anchor_grids instead of attaching it via
     # checkpoint_yolov5.yaml["anchors"]
+    num_anchors = checkpoint_yolov5.model[-1].anchors.shape[1]
     anchor_grids = (
         (checkpoint_yolov5.model[-1].anchors * checkpoint_yolov5.model[-1].stride.view(-1, 1, 1))
-        .reshape(1, -1, 6)
+        .reshape(1, -1, 2 * num_anchors)
         .tolist()[0]
     )