Skip to content

Commit

Permalink
[Enhancement] Support minus output feature index in mobilenet_v3 (#1005)
Browse files Browse the repository at this point in the history
* fix typo in mobilenet_v3

* fix typo in mobilenet_v3

* use -1 to indicate output tensors from final stage

* support negative out_indices
  • Loading branch information
luminxu authored Nov 15, 2021
1 parent ec2d4d0 commit 21ed733
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions mmpose/models/backbones/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class MobileNetV3(BaseBackbone):
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
out_indices (None or Sequence[int]): Output from which stages.
Default: (10, ), which means output tensors from final stage.
Default: (-1, ), which means output tensors from final stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self,
arch='small',
conv_cfg=None,
norm_cfg=dict(type='BN'),
out_indices=(10, ),
out_indices=(-1, ),
frozen_stages=-1,
norm_eval=False,
with_cp=False):
Expand All @@ -77,7 +77,8 @@ def __init__(self,
super().__init__()
assert arch in self.arch_settings
for index in out_indices:
if index not in range(0, len(self.arch_settings[arch])):
if index not in range(-len(self.arch_settings[arch]),
len(self.arch_settings[arch])):
raise ValueError('the item in out_indices must in '
f'range(0, {len(self.arch_settings[arch])}). '
f'But received {index}')
Expand All @@ -86,8 +87,6 @@ def __init__(self,
raise ValueError('frozen_stages must be in range(-1, '
f'{len(self.arch_settings[arch])}). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.arch = arch
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
Expand Down Expand Up @@ -162,7 +161,8 @@ def forward(self, x):
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
if i in self.out_indices or \
i - len(self.layers) in self.out_indices:
outs.append(x)

if len(outs) == 1:
Expand Down

0 comments on commit 21ed733

Please sign in to comment.