diff --git a/hubconf.py b/hubconf.py index a65f866c..c85acf43 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,5 +1,5 @@ # Optional list of dependencies required by the package dependencies = ['yaml', 'torch', 'torchvision'] -from models import yolov5s as yolov5 +from models import yolov5 from models import yolov5_onnx diff --git a/models/__init__.py b/models/__init__.py index a9769ad3..2c1ff2ec 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,20 +1,21 @@ from torch import nn from .common import Conv -from .yolo import yolov5s +from .yolo import yolov5 from utils.activations import Hardswish def yolov5_onnx( + cfg_path='yolov5s.yaml', pretrained=False, progress=True, num_classes=80, **kwargs, ): - model = yolov5s(pretrained=pretrained, progress=progress, - num_classes=num_classes, **kwargs) + model = yolov5(cfg_path=cfg_path, pretrained=pretrained, progress=progress, + num_classes=num_classes, **kwargs) for m in model.modules(): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility if isinstance(m, Conv) and isinstance(m.act, nn.Hardswish): diff --git a/models/backbone.py b/models/backbone.py index e0e30a8c..d3ab848c 100644 --- a/models/backbone.py +++ b/models/backbone.py @@ -1,5 +1,5 @@ # Modified from ultralytics/yolov5 by Zhiqiang Wang -import pathlib +from pathlib import Path from collections import OrderedDict import yaml @@ -189,7 +189,7 @@ def forward(self, x): def darknet(cfg_path='yolov5s.yaml', pretrained=False): - cfg_path = pathlib.Path(__file__).parent.absolute().joinpath(cfg_path) + cfg_path = Path(__file__).parent.absolute().joinpath(cfg_path) with open(cfg_path) as f: model_dict = yaml.load(f, Loader=yaml.FullLoader) diff --git a/models/yolo.py b/models/yolo.py index f96927b2..dbde5db0 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Modified by Zhiqiang Wang (zhiqwang@outlook.com) import warnings +from pathlib import Path import torch from torch import nn, Tensor @@ -151,13 +152,14 @@ def forward( model_urls = { - 'yolov5s': - 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.1/yolov5s.pt', + 'yolov5s': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.1/yolov5s.pt', + 'yolov5m': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.4/yolov5m.pt', + 'yolov5l': 'https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.2.4/yolov5l.pt', } -def yolov5s(pretrained=False, progress=True, - num_classes=80, pretrained_backbone=True, **kwargs): +def yolov5(cfg_path='yolov5s.yaml', pretrained=False, progress=True, + num_classes=80, pretrained_backbone=True, **kwargs): """ Constructs a YOLO model. @@ -185,7 +187,7 @@ def yolov5s(pretrained=False, progress=True, Example:: - >>> model = yolov5s(pretrained=True) + >>> model = yolov5(pretrained=True) >>> model.eval() >>> x = [torch.rand(3, 416, 320), torch.rand(3, 480, 352)] >>> predictions = model(x) @@ -198,9 +200,9 @@ def yolov5s(pretrained=False, progress=True, # no need to download the backbone if pretrained is set pretrained_backbone = False # skip P2 because it generates too many anchors (according to their paper) - backbone, anchor_grids = darknet(cfg_path='yolov5s.yaml', pretrained=pretrained_backbone) + backbone, anchor_grids = darknet(cfg_path=cfg_path, pretrained=pretrained_backbone) model = YOLO(backbone, num_classes, anchor_grids, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['yolov5s'], progress=progress) + state_dict = load_state_dict_from_url(model_urls[Path(cfg_path).stem], progress=progress) model.load_state_dict(state_dict) return model diff --git a/models/yolov5l.yaml b/models/yolov5l.yaml new file mode 100644 index 00000000..13095541 --- /dev/null +++ b/models/yolov5l.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, BottleneckCSP, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, BottleneckCSP, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, BottleneckCSP, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, BottleneckCSP, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, BottleneckCSP, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/models/yolov5m.yaml b/models/yolov5m.yaml new file mode 100644 index 00000000..eb50a713 --- /dev/null +++ b/models/yolov5m.yaml @@ -0,0 +1,48 @@ +# parameters +nc: 80 # number of classes +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple + +# anchors +anchors: + - [10,13, 16,30, 33,23] # P3/8 + - [30,61, 62,45, 59,119] # P4/16 + - [116,90, 156,198, 373,326] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Focus, [64, 3]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, BottleneckCSP, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 9, BottleneckCSP, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, BottleneckCSP, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32 + [-1, 1, SPP, [1024, [5, 9, 13]]], + [-1, 3, BottleneckCSP, [1024, False]], # 9 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, BottleneckCSP, [512, False]], # 13 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, BottleneckCSP, [256, False]], # 17 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 14], 1, Concat, [1]], # cat head P4 + [-1, 3, BottleneckCSP, [512, False]], # 20 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 10], 1, Concat, [1]], # cat head P5 + [-1, 3, BottleneckCSP, [1024, False]], # 23 (P5/32-large) + + [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/utils/updated_checkpoint.py b/utils/updated_checkpoint.py index d336684c..5d24f028 100644 --- a/utils/updated_checkpoint.py +++ b/utils/updated_checkpoint.py @@ -6,7 +6,13 @@ def update_ultralytics(model, checkpoint_path_ultralytics): + """ + It's limited that ultralytics saved model must load in their root path. + So a very important thing is to desensitize the path befor updating ultralytics's trained model as follows: + >>> checkpoints_ = torch.load(weights, map_location='cpu')['model'] + >>> torch.save(checkpoints_.state_dict(), './checkpoints/yolov5/yolov5s_ultralytics.pt') + """ state_dict = torch.load(checkpoint_path_ultralytics, map_location="cpu") # Update body features