Skip to content

Commit

Permalink
add Upsample CARAFE op additonl parms and cmputecost
Browse files Browse the repository at this point in the history
  • Loading branch information
positive666 committed Sep 8, 2021
1 parent 4ecf8ff commit b632b9d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 1 deletion.
42 changes: 42 additions & 0 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,48 @@ def forward(self, x):

return out

class CARAFE(nn.Module):
#CARAFE: Content-Aware ReAssembly of FEatures https://arxiv.org/pdf/1905.02188.pdf
def __init__(self, c1, c2, kernel_size=3, up_factor=2):
super(CARAFE, self).__init__()
self.kernel_size = kernel_size
self.up_factor = up_factor
self.down = nn.Conv2d(c1, c1 // 4, 1)
self.encoder = nn.Conv2d(c1 // 4, self.up_factor ** 2 * self.kernel_size ** 2,
self.kernel_size, 1, self.kernel_size // 2)
self.out = nn.Conv2d(c1, c2, 1)

def forward(self, x):
N, C, H, W = x.size()
# N,C,H,W -> N,C,delta*H,delta*W
# kernel prediction module
kernel_tensor = self.down(x) # (N, Cm, H, W)
kernel_tensor = self.encoder(kernel_tensor) # (N, S^2 * Kup^2, H, W)
kernel_tensor = F.pixel_shuffle(kernel_tensor, self.up_factor) # (N, S^2 * Kup^2, H, W)->(N, Kup^2, S*H, S*W)
kernel_tensor = F.softmax(kernel_tensor, dim=1) # (N, Kup^2, S*H, S*W)
kernel_tensor = kernel_tensor.unfold(2, self.up_factor, step=self.up_factor) # (N, Kup^2, H, W*S, S)
kernel_tensor = kernel_tensor.unfold(3, self.up_factor, step=self.up_factor) # (N, Kup^2, H, W, S, S)
kernel_tensor = kernel_tensor.reshape(N, self.kernel_size ** 2, H, W, self.up_factor ** 2) # (N, Kup^2, H, W, S^2)
kernel_tensor = kernel_tensor.permute(0, 2, 3, 1, 4) # (N, H, W, Kup^2, S^2)

# content-aware reassembly module
# tensor.unfold: dim, size, step
x = F.pad(x, pad=(self.kernel_size // 2, self.kernel_size // 2,
self.kernel_size // 2, self.kernel_size // 2),
mode='constant', value=0) # (N, C, H+Kup//2+Kup//2, W+Kup//2+Kup//2)
x = x.unfold(2, self.kernel_size, step=1) # (N, C, H, W+Kup//2+Kup//2, Kup)
x = x.unfold(3, self.kernel_size, step=1) # (N, C, H, W, Kup, Kup)
x = x.reshape(N, C, H, W, -1) # (N, C, H, W, Kup^2)
x = x.permute(0, 2, 3, 1, 4) # (N, H, W, C, Kup^2)

out_tensor = torch.matmul(x, kernel_tensor) # (N, H, W, C, S^2)
out_tensor = out_tensor.reshape(N, H, W, -1)
out_tensor = out_tensor.permute(0, 3, 1, 2)
out_tensor = F.pixel_shuffle(out_tensor, self.up_factor)
out_tensor = self.out(out_tensor)
#print("up shape:",out_tensor.shape)
return out_tensor

class TransformerLayer(nn.Module):
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
def __init__(self, c, num_heads):
Expand Down
2 changes: 1 addition & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)

n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,CBAM,ResBlock_CBAM,
CoordAtt,CrossConv,C3,CTR3,Involution, C3SPP, C3Ghost]:
CoordAtt,CrossConv,C3,CTR3,Involution, C3SPP, C3Ghost,CARAFE]:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
Expand Down
48 changes: 48 additions & 0 deletions models/yolov5s-carafe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple

# anchors
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32

# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Focus, [64, 3]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 9, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 1, SPP, [1024, [5, 9, 13]]],
[-1, 3, C3, [1024, False]], # 9
]

# YOLOv5 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, CARAFE, [512,3,2]],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13

[-1, 1, Conv, [256, 1, 1]],
[-1, 1, CARAFE, [256,3,2]],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)

[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 20 (P4/16-medium)

[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 23 (P5/32-large)

[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]

0 comments on commit b632b9d

Please sign in to comment.