From 4a07cb45b10b8c426c5fbc66c7cc7a628dfd0a42 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 29 Jan 2021 11:25:01 -0800 Subject: [PATCH] GhostConv update (#2082) --- models/experimental.py | 2 +- models/yolo.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/models/experimental.py b/models/experimental.py index 72dc877c83cf..5fe56858c54a 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -58,7 +58,7 @@ def forward(self, x): class GhostBottleneck(nn.Module): # Ghost Bottleneck https://github.com/huawei-noah/ghostnet - def __init__(self, c1, c2, k, s): + def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride super(GhostBottleneck, self).__init__() c_ = c2 // 2 self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw diff --git a/models/yolo.py b/models/yolo.py index db6ad01af541..11e6a65921a4 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -8,7 +8,7 @@ logger = logging.getLogger(__name__) from models.common import * -from models.experimental import MixConv2d, CrossConv +from models.experimental import * from utils.autoanchor import check_anchor_order from utils.general import make_divisible, check_file, set_logging from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ @@ -210,7 +210,8 @@ def parse_model(d, ch): # model_dict, input_channels(3) pass n = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: + if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, + C3]: c1, c2 = ch[f], args[0] # Normal