From 21ed7334df31627905613a4183f3482ee561cd9b Mon Sep 17 00:00:00 2001 From: Lumin <30328525+luminxu@users.noreply.github.com> Date: Mon, 15 Nov 2021 11:19:01 +0800 Subject: [PATCH] [Enhancement] Support minus output feature index in mobilenet_v3 (#1005) * fix typo in mobilenet_v3 * fix typo in mobilenet_v3 * use -1 to indicate output tensors from final stage * support negative out_indices --- mmpose/models/backbones/mobilenet_v3.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mmpose/models/backbones/mobilenet_v3.py b/mmpose/models/backbones/mobilenet_v3.py index f38aeb4d41..d640abec79 100644 --- a/mmpose/models/backbones/mobilenet_v3.py +++ b/mmpose/models/backbones/mobilenet_v3.py @@ -23,7 +23,7 @@ class MobileNetV3(BaseBackbone): norm_cfg (dict): Config dict for normalization layer. Default: dict(type='BN'). out_indices (None or Sequence[int]): Output from which stages. - Default: (10, ), which means output tensors from final stage. + Default: (-1, ), which means output tensors from final stage. frozen_stages (int): Stages to be frozen (all param fixed). Default: -1, which means not freezing any parameters. norm_eval (bool): Whether to set norm layers to eval mode, namely, @@ -68,7 +68,7 @@ def __init__(self, arch='small', conv_cfg=None, norm_cfg=dict(type='BN'), - out_indices=(10, ), + out_indices=(-1, ), frozen_stages=-1, norm_eval=False, with_cp=False): @@ -77,7 +77,8 @@ def __init__(self, super().__init__() assert arch in self.arch_settings for index in out_indices: - if index not in range(0, len(self.arch_settings[arch])): + if index not in range(-len(self.arch_settings[arch]), + len(self.arch_settings[arch])): raise ValueError('the item in out_indices must in ' f'range(0, {len(self.arch_settings[arch])}). ' f'But received {index}') @@ -86,8 +87,6 @@ def __init__(self, raise ValueError('frozen_stages must be in range(-1, ' f'{len(self.arch_settings[arch])}). ' f'But received {frozen_stages}') - self.out_indices = out_indices - self.frozen_stages = frozen_stages self.arch = arch self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg @@ -162,7 +161,8 @@ def forward(self, x): for i, layer_name in enumerate(self.layers): layer = getattr(self, layer_name) x = layer(x) - if i in self.out_indices: + if i in self.out_indices or \ + i - len(self.layers) in self.out_indices: outs.append(x) if len(outs) == 1: