Skip to content

Commit

Permalink
Update yolo.py channel array (#2223)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Feb 16, 2021
1 parent 7b833e3 commit f8464b4
Showing 1 changed file with 10 additions and 25 deletions.
35 changes: 10 additions & 25 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import sys
from copy import deepcopy
from pathlib import Path

sys.path.append('./') # to run '$ python *.py' files in subdirectories
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -213,43 +212,27 @@ def parse_model(d, ch): # model_dict, input_channels(3)
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP,
C3]:
c1, c2 = ch[f], args[0]

# Normal
# if i > 0 and args[0] != no: # channel expansion factor
# ex = 1.75 # exponential (default 2.0)
# e = math.log(c2 / ch[1]) / math.log(2)
# c2 = int(ch[1] * ex ** e)
# if m != Focus:

c2 = make_divisible(c2 * gw, 8) if c2 != no else c2

# Experimental
# if i > 0 and args[0] != no: # channel expansion factor
# ex = 1 + gw # exponential (default 2.0)
# ch1 = 32 # ch[1]
# e = math.log(c2 / ch1) / math.log(2) # level 1-n
# c2 = int(ch1 * ex ** e)
# if m != Focus:
# c2 = make_divisible(c2, 8) if c2 != no else c2
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)

args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
args.insert(2, n)
args.insert(2, n) # number of repeats
n = 1
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum([ch[x if x < 0 else x + 1] for x in f])
c2 = sum([ch[x] for x in f])
elif m is Detect:
args.append([ch[x + 1] for x in f])
args.append([ch[x] for x in f])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
elif m is Contract:
c2 = ch[f if f < 0 else f + 1] * args[0] ** 2
c2 = ch[f] * args[0] ** 2
elif m is Expand:
c2 = ch[f if f < 0 else f + 1] // args[0] ** 2
c2 = ch[f] // args[0] ** 2
else:
c2 = ch[f if f < 0 else f + 1]
c2 = ch[f]

m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace('__main__.', '') # module type
Expand All @@ -258,6 +241,8 @@ def parse_model(d, ch): # model_dict, input_channels(3)
logger.info('%3s%18s%3s%10.0f %-40s%-30s' % (i, f, n, np, t, args)) # print
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
if i == 0:
ch = []
ch.append(c2)
return nn.Sequential(*layers), sorted(save)

Expand Down

0 comments on commit f8464b4

Please sign in to comment.