Skip to content

Commit

Permalink
Simpler code for DWConvClass (#4310)
Browse files Browse the repository at this point in the history
* more simpler code for DWConvClass

more simpler code for DWConvClass

* remove DWConv function

* Replace DWConvClass with DWConv
  • Loading branch information
developer0hye authored Aug 5, 2021
1 parent f409d8e commit e96c74b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 2 additions & 8 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,6 @@ def autopad(k, p=None): # kernel, padding
return p


def DWConv(c1, c2, k=1, s=1, act=True):
# Depth-wise convolution function
return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class Conv(nn.Module):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
Expand All @@ -49,11 +44,10 @@ def forward_fuse(self, x):
return self.act(self.conv(x))


class DWConvClass(Conv):
class DWConv(Conv):
# Depth-wise convolution class
def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super().__init__(c1, c2, k, s, act)
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k), groups=math.gcd(c1, c2), bias=False)
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)


class TransformerLayer(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _print_biases(self):
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
LOGGER.info('Fusing layers... ')
for m in self.model.modules():
if isinstance(m, (Conv, DWConvClass)) and hasattr(m, 'bn'):
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward
Expand Down

0 comments on commit e96c74b

Please sign in to comment.