From 09ddede70d1b8737c78b110711e1936fcce1c2df Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 20 Mar 2022 18:03:37 +0100 Subject: [PATCH] Update common.py lists for tuples (#7063) Improved profiling. --- models/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/models/common.py b/models/common.py index 5561d92ecb73..066f8774d3c3 100644 --- a/models/common.py +++ b/models/common.py @@ -31,7 +31,7 @@ def autopad(k, p=None): # kernel, padding # Pad to 'same' if p is None: - p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + p = k // 2 if isinstance(k, int) else (x // 2 for x in k) # auto-pad return p @@ -133,7 +133,7 @@ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, nu self.cv2 = Conv(c1, c_, 1, 1) self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) - # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) + # self.m = nn.Sequential(*(CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n))) def forward(self, x): return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) @@ -194,7 +194,7 @@ def forward(self, x): warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning y1 = self.m(x) y2 = self.m(y1) - return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) + return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1)) class Focus(nn.Module): @@ -205,7 +205,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k # self.contract = Contract(gain=2) def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) - return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1)) # return self.conv(self.contract(x)) @@ -219,7 +219,7 @@ def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, s def forward(self, x): y = self.cv1(x) - return torch.cat([y, self.cv2(y)], 1) + return torch.cat((y, self.cv2(y)), 1) class GhostBottleneck(nn.Module):