Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Support minus output feature index in mobilenet_v3 #1005

Merged
merged 4 commits into from
Nov 15, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mmpose/models/backbones/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class MobileNetV3(BaseBackbone):
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
out_indices (None or Sequence[int]): Output from which stages.
Default: (10, ), which means output tensors from final stage.
Default: (-1, ), which means output tensors from final stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(self,
arch='small',
conv_cfg=None,
norm_cfg=dict(type='BN'),
out_indices=(10, ),
out_indices=(-1, ),
frozen_stages=-1,
norm_eval=False,
with_cp=False):
Expand All @@ -77,7 +77,8 @@ def __init__(self,
super().__init__()
assert arch in self.arch_settings
for index in out_indices:
if index not in range(0, len(self.arch_settings[arch])):
if index not in range(-len(self.arch_settings[arch]),
len(self.arch_settings[arch])):
raise ValueError('the item in out_indices must in '
f'range(0, {len(self.arch_settings[arch])}). '
f'But received {index}')
Expand All @@ -86,8 +87,6 @@ def __init__(self,
raise ValueError('frozen_stages must be in range(-1, '
f'{len(self.arch_settings[arch])}). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
jin-s13 marked this conversation as resolved.
Show resolved Hide resolved
self.arch = arch
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
Expand Down Expand Up @@ -162,7 +161,8 @@ def forward(self, x):
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
if i in self.out_indices or \
i - len(self.layers) in self.out_indices:
outs.append(x)

if len(outs) == 1:
Expand Down