diff --git a/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.md b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.md new file mode 100644 index 0000000000..6ab872452c --- /dev/null +++ b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.md @@ -0,0 +1,38 @@ + + +
+TCFormer (CVPR'2022) + +```bibtex +@inproceedings{zeng2022not, + title={Not All Tokens Are Equal: Human-centric Visual Analysis via Token Clustering Transformer}, + author={Zeng, Wang and Jin, Sheng and Liu, Wentao and Qian, Chen and Luo, Ping and Ouyang, Wanli and Wang, Xiaogang}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={11101--11111}, + year={2022} +} +``` + +
+ + + +
+COCO-WholeBody (ECCV'2020) + +```bibtex +@inproceedings{jin2020whole, + title={Whole-Body Human Pose Estimation in the Wild}, + author={Jin, Sheng and Xu, Lumin and Xu, Jin and Wang, Can and Liu, Wentao and Qian, Chen and Ouyang, Wanli and Luo, Ping}, + booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, + year={2020} +} +``` + +
+ +Results on COCO-WholeBody v1.0 val with detector having human AP of 56.4 on COCO val2017 dataset + +| Arch | Input Size | Body AP | Body AR | Foot AP | Foot AR | Face AP | Face AR | Hand AP | Hand AR | Whole AP | Whole AR | ckpt | log | +| :-------------------------------------- | :--------: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :------: | :------: | :--------------------------------------: | :-------------------------------------: | +| [tcformer](/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py) | 256x192 | 0.691 | 0.769 | 0.690 | 0.809 | 0.650 | 0.747 | 0.534 | 0.647 | 0.574 | 0.678 | [ckpt](https://download.openmmlab.com/mmpose/top_down/tcformer/tcformer_coco-wholebody_256x192-a0720efa_20220627.pth) | [log](https://download.openmmlab.com/mmpose/top_down/tcformer/tcformer_coco-wholebody_256x192_20220627.log.json) | diff --git a/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml new file mode 100644 index 0000000000..8504e196d1 --- /dev/null +++ b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml @@ -0,0 +1,30 @@ +Collections: +- Name: TCFormer + Paper: + Title: 'Not All Tokens Are Equal: Human-centric Visual Analysis via Token Clustering + Transformer' + URL: https://openaccess.thecvf.com/content/CVPR2022/html/Zeng_Not_All_Tokens_Are_Equal_Human-Centric_Visual_Analysis_via_Token_CVPR_2022_paper.html + README: https://github.com/open-mmlab/mmpose/blob/master/docs/en/papers/backbones/tcformer.md +Models: +- Config: configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py + In Collection: TCFormer + Metadata: + Architecture: + - TCFormer + Training Data: COCO-WholeBody + Name: topdown_heatmap_tcformer_coco_wholebody_256x192 + Results: + - Dataset: COCO-WholeBody + Metrics: + Body AP: 0.691 + Body AR: 0.769 + Face AP: 0.65 + Face AR: 0.747 + Foot AP: 0.69 + Foot AR: 0.809 + Hand AP: 0.534 + Hand AR: 0.647 + Whole AP: 0.574 + Whole AR: 0.678 + Task: Wholebody 2D Keypoint + Weights: https://download.openmmlab.com/mmpose/top_down/tcformer/tcformer_coco-wholebody_256x192-a0720efa_20220627.pth diff --git a/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py new file mode 100644 index 0000000000..73d7898686 --- /dev/null +++ b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py @@ -0,0 +1,171 @@ +_base_ = ['../../../../_base_/datasets/coco_wholebody.py'] +log_level = 'INFO' +load_from = None +resume_from = None +dist_params = dict(backend='nccl') +workflow = [('train', 1)] +checkpoint_config = dict(interval=10) +evaluation = dict(interval=10, metric='mAP', save_best='AP') + +optimizer = dict( + type='AdamW', + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.01, +) + +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +channel_cfg = dict( + num_output_channels=133, + dataset_joints=133, + 
dataset_channel=[ + list(range(133)), + ], + inference_channel=list(range(133))) + +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='TopDown', + pretrained='https://download.openmmlab.com/mmpose/' + 'pretrain_models/tcformer-4e1adbf1_20220421.pth', + backbone=dict( + type='TCFormer', + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + mlp_ratios=[8, 8, 4, 4], + qkv_bias=True, + num_layers=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + drop_path_rate=0.1), + neck=dict( + type='MTA', + in_channels=[64, 128, 320, 512], + out_channels=256, + start_level=0, + num_heads=[4, 4, 4, 4], + mlp_ratios=[4, 4, 4, 4], + num_outs=4, + use_sr_conv=False, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=256, + out_channels=channel_cfg['num_output_channels'], + num_deconv_layers=0, + extra=dict(final_conv_kernel=1, ), + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=True, + modulate_kernel=11)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='data/coco/person_detection_results/' + 'COCO_val2017_detections_AP_H_56_person.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownGetBboxCenterScale', padding=1.25), + dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine'), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict(type='TopDownGenerateTarget', sigma=2), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownGetBboxCenterScale', padding=1.25), + dict(type='TopDownAffine'), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = 'data/coco' +data = dict( + samples_per_gpu=64, + workers_per_gpu=2, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoWholeBodyDataset', + ann_file=f'{data_root}/annotations/coco_wholebody_train_v1.0.json', + img_prefix=f'{data_root}/train2017/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoWholeBodyDataset', + ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + test=dict( + type='TopDownCocoWholeBodyDataset', + 
ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json', + img_prefix=f'{data_root}/val2017/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), +) diff --git a/docs/en/papers/backbones/tcformer.md b/docs/en/papers/backbones/tcformer.md new file mode 100644 index 0000000000..26e7714bba --- /dev/null +++ b/docs/en/papers/backbones/tcformer.md @@ -0,0 +1,53 @@ +# Not All Tokens Are Equal: Human-Centric Visual Analysis via Token Clustering Transformer + + + +
+TCFormer (CVPR'2022) + +```bibtex +@inproceedings{zeng2022not, + title={Not All Tokens Are Equal: Human-centric Visual Analysis via Token Clustering Transformer}, + author={Zeng, Wang and Jin, Sheng and Liu, Wentao and Qian, Chen and Luo, Ping and Ouyang, Wanli and Wang, Xiaogang}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={11101--11111}, + year={2022} +} +``` + +
+ +## Abstract + + + +Vision transformers have achieved great successes in +many computer vision tasks. Most methods generate +vision tokens by splitting an image into a regular +and fixed grid and treating each cell as a token. +However, not all regions are equally important in +human-centric vision tasks, e.g., the human body +needs a fine representation with many tokens, while +the image background can be modeled by a few tokens. +To address this problem, we propose a novel Vision +Transformer, called Token Clustering Transformer +(TCFormer), which merges tokens by progressive +clustering, where the tokens can be merged from +different locations with flexible shapes and sizes. +The tokens in TCFormer can not only focus on important +areas but also adjust the token shapes to fit the +semantic concept and adopt a fine resolution for +regions containing critical details, which is +beneficial to capturing detailed information. +Extensive experiments show that TCFormer consistently +outperforms its counterparts on different challenging +humancentric tasks and datasets, including whole-body +pose estimation on COCO-WholeBody and 3D human mesh +reconstruction on 3DPW. Code is available at +https://github.com/zengwang430521/TCFormer.git. + + + +
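The progressive clustering described above is implemented in this PR by `cluster_dpc_knn` and `merge_tokens` (see `mmpose/models/utils/tcformer_utils.py`). Below is a minimal sketch of merging a toy grid of tokens, assuming the import paths exported by this PR; the random features and the 0.25 keep ratio are arbitrary choices for illustration (the ratio happens to match the default `sample_ratios` of the CTM modules).

```python
import torch

# Utilities added in this PR (mmpose/models/utils/tcformer_utils.py).
from mmpose.models.utils import cluster_dpc_knn, merge_tokens

B, H, W, C = 1, 16, 12, 64  # toy sizes, chosen only for illustration
N = H * W

# Token dict in the format TCFormer uses internally: one token per grid
# cell, uniform aggregation weights.
token_dict = {
    'x': torch.randn(B, N, C),                      # token features
    'token_num': N,
    'map_size': [H, W],
    'init_grid_size': [H, W],
    'idx_token': torch.arange(N)[None].repeat(B, 1),
    'agg_weight': torch.ones(B, N, 1),
}

# Cluster the tokens with DPC-KNN and merge each cluster into one token,
# keeping a quarter of them.
cluster_num = N // 4
idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num, k=5)
merged = merge_tokens(token_dict, idx_cluster, cluster_num)

print(merged['x'].shape)          # torch.Size([1, 48, 64])
print(merged['idx_token'].shape)  # torch.Size([1, 192]); maps every initial
                                  # grid cell to its merged token
```

Unlike this sketch, the CTM modules in the backbone also apply a strided `TokenConv` and weight each token by a learned importance score before merging.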
+ +
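For a quick end-to-end check, the backbone and the MTA neck registered by this PR can be chained directly. The sketch below mirrors the hyper-parameters of `tcformer_coco_wholebody_256x192.py` (without loading the pretrained weights), and the expected shapes follow the unit tests added in this PR; it is an illustration, not an official snippet.

```python
import torch

# Modules registered by this PR.
from mmpose.models.backbones import TCFormer
from mmpose.models.necks import MTA

# Same backbone/neck settings as the COCO-WholeBody config above.
backbone = TCFormer(
    embed_dims=[64, 128, 320, 512],
    num_heads=[1, 2, 5, 8],
    mlp_ratios=[8, 8, 4, 4],
    qkv_bias=True,
    num_layers=[3, 4, 6, 3],
    sr_ratios=[8, 4, 2, 1],
    drop_path_rate=0.1)
neck = MTA(
    in_channels=[64, 128, 320, 512],
    out_channels=256,
    start_level=0,
    num_heads=[4, 4, 4, 4],
    mlp_ratios=[4, 4, 4, 4],
    num_outs=4,
    use_sr_conv=False)

img = torch.randn(1, 3, 256, 192)   # input size used by the config
token_dicts = backbone(img)         # 4 token dicts, one per stage
feats = neck(token_dicts)           # 4 feature maps at strides 4/8/16/32

for f in feats:
    print(f.shape)
# torch.Size([1, 256, 64, 48])
# torch.Size([1, 256, 32, 24])
# torch.Size([1, 256, 16, 12])
# torch.Size([1, 256, 8, 6])
```

Each backbone stage returns a token dict rather than a feature map; `MTA` converts the dicts back into maps with `token2map`, which is why the keypoint head in the config can remain a plain `TopdownHeatmapSimpleHead`.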
diff --git a/mmpose/models/backbones/__init__.py b/mmpose/models/backbones/__init__.py index 09745d443c..2fc64a8af3 100644 --- a/mmpose/models/backbones/__init__.py +++ b/mmpose/models/backbones/__init__.py @@ -22,6 +22,7 @@ from .shufflenet_v1 import ShuffleNetV1 from .shufflenet_v2 import ShuffleNetV2 from .swin import SwinTransformer +from .tcformer import TCFormer from .tcn import TCN from .v2v_net import V2VNet from .vgg import VGG @@ -34,5 +35,5 @@ 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN', 'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3', 'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer', - 'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D' + 'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer' ] diff --git a/mmpose/models/backbones/tcformer.py b/mmpose/models/backbones/tcformer.py new file mode 100644 index 0000000000..99cd44d266 --- /dev/null +++ b/mmpose/models/backbones/tcformer.py @@ -0,0 +1,284 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn import (build_norm_layer, constant_init, normal_init, + trunc_normal_init) +from mmcv.runner import _load_checkpoint, load_state_dict + +from ...utils import get_root_logger +from ..builder import BACKBONES +from ..utils import (PatchEmbed, TCFormerDynamicBlock, TCFormerRegularBlock, + TokenConv, cluster_dpc_knn, merge_tokens, + tcformer_convert, token2map) + + +class CTM(nn.Module): + """Clustering-based Token Merging module in TCFormer. + + Args: + sample_ratio (float): The sample ratio of tokens. + embed_dim (int): Input token feature dimension. + dim_out (int): Output token feature dimension. + k (int): number of the nearest neighbor used i DPC-knn algorithm. + """ + + def __init__(self, sample_ratio, embed_dim, dim_out, k=5): + super().__init__() + self.sample_ratio = sample_ratio + self.dim_out = dim_out + self.conv = TokenConv( + in_channels=embed_dim, + out_channels=dim_out, + kernel_size=3, + stride=2, + padding=1) + self.norm = nn.LayerNorm(self.dim_out) + self.score = nn.Linear(self.dim_out, 1) + self.k = k + + def forward(self, token_dict): + token_dict = token_dict.copy() + x = self.conv(token_dict) + x = self.norm(x) + token_score = self.score(x) + token_weight = token_score.exp() + + token_dict['x'] = x + B, N, C = x.shape + token_dict['token_score'] = token_score + + cluster_num = max(math.ceil(N * self.sample_ratio), 1) + idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num, + self.k) + down_dict = merge_tokens(token_dict, idx_cluster, cluster_num, + token_weight) + + H, W = token_dict['map_size'] + H = math.floor((H - 1) / 2 + 1) + W = math.floor((W - 1) / 2 + 1) + down_dict['map_size'] = [H, W] + + return down_dict, token_dict + + +@BACKBONES.register_module() +class TCFormer(nn.Module): + """Token Clustering Transformer (TCFormer) + + Implementation of `Not All Tokens Are Equal: Human-centric Visual + Analysis via Token Clustering Transformer + ` + + Args: + in_channels (int): Number of input channels. Default: 3. + embed_dims (list[int]): Embedding dimension. Default: + [64, 128, 256, 512]. + num_heads (Sequence[int]): The attention heads of each transformer + encode layer. Default: [1, 2, 5, 8]. + mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the + embedding dim of each transformer block. + qkv_bias (bool): Enable bias for qkv if True. Default: True. 
+ qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN', eps=1e-6). + num_layers (Sequence[int]): The layer number of each transformer encode + layer. Default: [3, 4, 6, 3]. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer block. Default: [8, 4, 2, 1]. + num_stages (int): The num of stages. Default: 4. + pretrained (str, optional): model pretrained path. Default: None. + k (int): number of the nearest neighbor used for local density. + sample_ratios (list[float]): The sample ratios of CTM modules. + Default: [0.25, 0.25, 0.25] + return_map (bool): If True, transfer dynamic tokens to feature map at + last. Default: False + convert_weights (bool): The flag indicates whether the + pre-trained model is from the original repo. We may need + to convert some keys to make it compatible. + Default: True. + """ + + def __init__(self, + in_channels=3, + embed_dims=[64, 128, 256, 512], + num_heads=[1, 2, 4, 8], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_cfg=dict(type='LN', eps=1e-6), + num_layers=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + num_stages=4, + pretrained=None, + k=5, + sample_ratios=[0.25, 0.25, 0.25], + return_map=False, + convert_weights=True): + super().__init__() + + self.num_layers = num_layers + self.num_stages = num_stages + self.grid_stride = sr_ratios[0] + self.embed_dims = embed_dims + self.sr_ratios = sr_ratios + self.mlp_ratios = mlp_ratios + self.sample_ratios = sample_ratios + self.return_map = return_map + self.convert_weights = convert_weights + + # stochastic depth decay rule + dpr = [ + x.item() + for x in torch.linspace(0, drop_path_rate, sum(num_layers)) + ] + cur = 0 + + # In stage 1, use the standard transformer blocks + for i in range(1): + patch_embed = PatchEmbed( + in_channels=in_channels if i == 0 else embed_dims[i - 1], + embed_dims=embed_dims[i], + kernel_size=7, + stride=4, + padding=3, + bias=True, + norm_cfg=dict(type='LN', eps=1e-6)) + + block = nn.ModuleList([ + TCFormerRegularBlock( + dim=embed_dims[i], + num_heads=num_heads[i], + mlp_ratio=mlp_ratios[i], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + j], + norm_cfg=norm_cfg, + sr_ratio=sr_ratios[i]) for j in range(num_layers[i]) + ]) + norm = build_norm_layer(norm_cfg, embed_dims[i])[1] + + cur += num_layers[i] + + setattr(self, f'patch_embed{i + 1}', patch_embed) + setattr(self, f'block{i + 1}', block) + setattr(self, f'norm{i + 1}', norm) + + # In stage 2~4, use TCFormerDynamicBlock for dynamic tokens + for i in range(1, num_stages): + ctm = CTM(sample_ratios[i - 1], embed_dims[i - 1], embed_dims[i], + k) + + block = nn.ModuleList([ + TCFormerDynamicBlock( + dim=embed_dims[i], + num_heads=num_heads[i], + mlp_ratio=mlp_ratios[i], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[cur + j], + norm_cfg=norm_cfg, + sr_ratio=sr_ratios[i]) for j in range(num_layers[i]) + ]) + norm = build_norm_layer(norm_cfg, embed_dims[i])[1] + cur += num_layers[i] + + setattr(self, f'ctm{i}', ctm) + setattr(self, f'block{i + 1}', block) + setattr(self, 
f'norm{i + 1}', norm) + + self.init_weights(pretrained) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + + checkpoint = _load_checkpoint( + pretrained, logger=logger, map_location='cpu') + logger.warning(f'Load pre-trained model for ' + f'{self.__class__.__name__} from original repo') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + + if self.convert_weights: + # We need to convert pre-trained weights to match this + # implementation. + state_dict = tcformer_convert(state_dict) + load_state_dict(self, state_dict, strict=False, logger=logger) + + elif pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + constant_init(m, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init(m, 0, math.sqrt(2.0 / fan_out)) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + outs = [] + + i = 0 + patch_embed = getattr(self, f'patch_embed{i + 1}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, (H, W) = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + + # init token dict + B, N, _ = x.shape + device = x.device + idx_token = torch.arange(N)[None, :].repeat(B, 1).to(device) + agg_weight = x.new_ones(B, N, 1) + token_dict = { + 'x': x, + 'token_num': N, + 'map_size': [H, W], + 'init_grid_size': [H, W], + 'idx_token': idx_token, + 'agg_weight': agg_weight + } + outs.append(token_dict.copy()) + + # stage 2~4 + for i in range(1, self.num_stages): + ctm = getattr(self, f'ctm{i}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + + token_dict = ctm(token_dict) # down sample + for j, blk in enumerate(block): + token_dict = blk(token_dict) + + token_dict['x'] = norm(token_dict['x']) + outs.append(token_dict) + + if self.return_map: + outs = [token2map(token_dict) for token_dict in outs] + + return outs diff --git a/mmpose/models/necks/__init__.py b/mmpose/models/necks/__init__.py index 7f4160f84d..16381b2659 100644 --- a/mmpose/models/necks/__init__.py +++ b/mmpose/models/necks/__init__.py @@ -2,5 +2,6 @@ from .fpn import FPN from .gap_neck import GlobalAveragePooling from .posewarper_neck import PoseWarperNeck +from .tcformer_mta_neck import MTA -__all__ = ['GlobalAveragePooling', 'PoseWarperNeck', 'FPN'] +__all__ = ['GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'MTA'] diff --git a/mmpose/models/necks/tcformer_mta_neck.py b/mmpose/models/necks/tcformer_mta_neck.py new file mode 100644 index 0000000000..6723fb018e --- /dev/null +++ b/mmpose/models/necks/tcformer_mta_neck.py @@ -0,0 +1,224 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, constant_init, normal_init, trunc_normal_init +from mmcv.runner import BaseModule + +from ..builder import NECKS +from ..utils import TCFormerDynamicBlock, token2map, token_interp + + +@NECKS.register_module() +class MTA(BaseModule): + """Multi-stage Token feature Aggregation (MTA) module in TCFormer. + + Args: + in_channels (list[int]): Number of input channels per stage. + Default: [64, 128, 256, 512]. + out_channels (int): Number of output channels (used at each scale). 
+ num_outs (int): Number of output scales. Default: 4. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + - 'on_input': Last feat map of neck inputs (i.e. backbone feature). + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer in ConvModule. + num_heads (Sequence[int]): The attention heads of each transformer + block. Default: [2, 2, 2, 2]. + mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the + embedding dim of each transformer block. + sr_ratios (Sequence[int]): The spatial reduction rate of each + transformer block. Default: [8, 4, 2, 1]. + qkv_bias (bool): Enable bias for qkv if True. Default: True. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop_rate (float): Probability of an element to be zeroed. + Default 0.0. + attn_drop_rate (float): The drop out rate for attention layer. + Default 0.0. + drop_path_rate (float): stochastic depth rate. Default 0. + transformer_norm_cfg (dict): Config dict for normalization layer + in transformer blocks. Default: dict(type='LN'). + use_sr_conv (bool): If True, use a conv layer for spatial reduction. + If False, use a pooling process for spatial reduction. Defaults: + False. 
+ """ + + def __init__( + self, + in_channels=[64, 128, 256, 512], + out_channels=128, + num_outs=4, + start_level=0, + end_level=-1, + add_extra_convs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + num_heads=[2, 2, 2, 2], + mlp_ratios=[4, 4, 4, 4], + sr_ratios=[8, 4, 2, 1], + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + transformer_norm_cfg=dict(type='LN'), + use_sr_conv=False, + ): + super().__init__() + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.norm_cfg = norm_cfg + self.conv_cfg = conv_cfg + self.act_cfg = act_cfg + self.mlp_ratios = mlp_ratios + + if end_level == -1 or end_level == self.num_ins - 1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level is not the last level, no extra level is allowed + self.backbone_end_level = end_level + 1 + assert end_level < self.num_ins + assert num_outs == end_level - start_level + 1 + self.start_level = start_level + self.end_level = end_level + + self.lateral_convs = nn.ModuleList() + self.merge_blocks = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + + for i in range(self.start_level, self.backbone_end_level - 1): + merge_block = TCFormerDynamicBlock( + dim=out_channels, + num_heads=num_heads[i], + mlp_ratio=mlp_ratios[i], + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=drop_path_rate, + norm_cfg=transformer_norm_cfg, + sr_ratio=sr_ratios[i], + use_sr_conv=use_sr_conv) + self.merge_blocks.append(merge_block) + + # add extra conv layers (e.g., RetinaNet) + self.relu_before_extra_convs = relu_before_extra_convs + + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_output' + assert add_extra_convs in ('on_input', 'on_output') + elif add_extra_convs: # True + self.add_extra_convs = 'on_input' + + self.extra_convs = nn.ModuleList() + extra_levels = num_outs - (self.end_level + 1 - self.start_level) + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.end_level] + else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.extra_convs.append(extra_fpn_conv) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) 
+ elif isinstance(m, nn.LayerNorm): + constant_init(m, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + normal_init(m, 0, math.sqrt(2.0 / fan_out)) + + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build lateral tokens + input_dicts = [] + for i, lateral_conv in enumerate(self.lateral_convs): + tmp = inputs[i + self.start_level].copy() + tmp['x'] = lateral_conv(tmp['x'].unsqueeze(2).permute( + 0, 3, 1, 2)).permute(0, 2, 3, 1).squeeze(2) + input_dicts.append(tmp) + + # merge from high level to low level + for i in range(len(input_dicts) - 2, -1, -1): + input_dicts[i]['x'] = input_dicts[i]['x'] + token_interp( + input_dicts[i], input_dicts[i + 1]) + input_dicts[i] = self.merge_blocks[i](input_dicts[i]) + + # transform to feature map + outs = [token2map(token_dict) for token_dict in input_dicts] + + # part 2: add extra levels + used_backbone_levels = len(outs) + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps + else: + if self.add_extra_convs == 'on_input': + tmp = inputs[self.backbone_end_level - 1] + extra_source = token2map(tmp) + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + + outs.append(self.extra_convs[0](extra_source)) + for i in range(1, self.num_outs - used_backbone_levels): + if self.relu_before_extra_convs: + outs.append(self.extra_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.extra_convs[i](outs[-1])) + return outs diff --git a/mmpose/models/utils/__init__.py b/mmpose/models/utils/__init__.py index 9be0b9a08c..d2ed972655 100644 --- a/mmpose/models/utils/__init__.py +++ b/mmpose/models/utils/__init__.py @@ -1,14 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .ckpt_convert import pvt_convert +from .ckpt_convert import pvt_convert, tcformer_convert from .geometry import batch_rodrigues, quat_to_rotmat, rot6d_to_rotmat from .misc import torch_meshgrid_ij from .ops import resize from .realnvp import RealNVP from .smpl import SMPL +from .tcformer_utils import (TCFormerDynamicBlock, TCFormerRegularBlock, + TokenConv, cluster_dpc_knn, merge_tokens, + token2map, token_interp) from .transformer import PatchEmbed, PatchMerging, nchw_to_nlc, nlc_to_nchw __all__ = [ 'SMPL', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'pvt_convert', 'PatchMerging', 'batch_rodrigues', 'quat_to_rotmat', 'rot6d_to_rotmat', - 'resize', 'RealNVP', 'torch_meshgrid_ij' + 'resize', 'RealNVP', 'torch_meshgrid_ij', 'token2map', 'TokenConv', + 'TCFormerRegularBlock', 'TCFormerDynamicBlock', 'cluster_dpc_knn', + 'merge_tokens', 'token_interp', 'tcformer_convert' ] diff --git a/mmpose/models/utils/ckpt_convert.py b/mmpose/models/utils/ckpt_convert.py index 05f5cdb4a3..f5213937db 100644 --- a/mmpose/models/utils/ckpt_convert.py +++ b/mmpose/models/utils/ckpt_convert.py @@ -80,3 +80,15 @@ def pvt_convert(ckpt): new_ckpt[new_k] = new_v return new_ckpt + + +def tcformer_convert(ckpt): + new_ckpt = OrderedDict() + # Process the concat between q linear weights and kv linear weights + for k, v in ckpt.items(): + if 'patch_embed' in k: + new_k = k.replace('.proj.', '.projection.') + else: + new_k = k + new_ckpt[new_k] = v + return new_ckpt diff --git a/mmpose/models/utils/tcformer_utils.py b/mmpose/models/utils/tcformer_utils.py new file mode 100644 index 0000000000..ff7e09dba4 --- /dev/null +++ b/mmpose/models/utils/tcformer_utils.py @@ -0,0 +1,960 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import build_norm_layer, trunc_normal_init +from mmcv.cnn.bricks.transformer import build_dropout + + +def get_grid_index(init_grid_size, map_size, device): + """For every initial grid, get its index in the feature map. + Note: + [H_init, W_init]: shape of initial grid + [H, W]: shape of feature map + N_init: numbers of initial token + + Args: + init_grid_size (list[int] or tuple[int]): initial grid resolution in + format [H_init, W_init]. + map_size (list[int] or tuple[int]): feature map resolution in format + [H, W]. + device: the device of output + + Returns: + idx (torch.LongTensor[B, N_init]): index in flattened feature map. + """ + H_init, W_init = init_grid_size + H, W = map_size + idx = torch.arange(H * W, device=device).reshape(1, 1, H, W) + idx = F.interpolate(idx.float(), [H_init, W_init], mode='nearest').long() + return idx.flatten() + + +def index_points(points, idx): + """Sample features following the index. + Note: + B: batch size + N: point number + C: channel number of each point + Ns: sampled point number + + Args: + points (torch.Tensor[B, N, C]): input points data + idx (torch.LongTensor[B, S]): sample index + + Returns: + new_points (torch.Tensor[B, Ns, C]):, indexed points data + """ + device = points.device + B = points.shape[0] + view_shape = list(idx.shape) + view_shape[1:] = [1] * (len(view_shape) - 1) + repeat_shape = list(idx.shape) + repeat_shape[0] = 1 + batch_indices = torch.arange( + B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) + new_points = points[batch_indices, idx, :] + return new_points + + +def token2map(token_dict): + """Transform vision tokens to feature map. 
This function only works when + the resolution of the feature map is not higher than the initial grid + structure. + + Note: + B: batch size + C: channel number of each token + [H, W]: shape of feature map + N_init: numbers of initial token + + Args: + token_dict (dict): dict for token information. + + Returns: + x_out (Tensor[B, C, H, W]): feature map. + """ + + x = token_dict['x'] + H, W = token_dict['map_size'] + H_init, W_init = token_dict['init_grid_size'] + idx_token = token_dict['idx_token'] + B, N, C = x.shape + N_init = H_init * W_init + device = x.device + + if N_init == N and N == H * W: + # for the initial tokens with grid structure, just reshape + return x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + + # for each initial grid, get the corresponding index in + # the flattened feature map. + idx_hw = get_grid_index([H_init, W_init], [H, W], + device=device)[None, :].expand(B, -1) + idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init) + value = x.new_ones(B * N_init) + + # choose the way with fewer flops. + if N_init < N * H * W: + # use sparse matrix multiplication + # Flops: B * N_init * (C+2) + idx_hw = idx_hw + idx_batch * H * W + idx_tokens = idx_token + idx_batch * N + coor = torch.stack([idx_hw, idx_tokens], dim=0).reshape(2, B * N_init) + + # torch.sparse do not support gradient for + # sparse tensor, so we detach it + value = value.detach().to(torch.float32) + + # build a sparse matrix with the shape [B * H * W, B * N] + A = torch.sparse.FloatTensor(coor, value, + torch.Size([B * H * W, B * N])) + + # normalize the weight for each row + all_weight = A @ x.new_ones(B * N, 1).type(torch.float32) + 1e-6 + value = value / all_weight[idx_hw.reshape(-1), 0] + + # update the matrix with normalize weight + A = torch.sparse.FloatTensor(coor, value, + torch.Size([B * H * W, B * N])) + + # sparse matrix multiplication + x_out = A @ x.reshape(B * N, C).to(torch.float32) # [B*H*W, C] + + else: + # use dense matrix multiplication + # Flops: B * N * H * W * (C+2) + coor = torch.stack([idx_batch, idx_hw, idx_token], + dim=0).reshape(3, B * N_init) + + # build a matrix with shape [B, H*W, N] + A = torch.sparse.FloatTensor(coor, value, torch.Size([B, H * W, + N])).to_dense() + # normalize the weight + A = A / (A.sum(dim=-1, keepdim=True) + 1e-6) + + x_out = A @ x # [B, H*W, C] + + x_out = x_out.type(x.dtype) + x_out = x_out.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous() + return x_out + + +def map2token(feature_map, token_dict): + """Transform feature map to vision tokens. This function only works when + the resolution of the feature map is not higher than the initial grid + structure. + + Note: + B: batch size + C: channel number + [H, W]: shape of feature map + N_init: numbers of initial token + + Args: + feature_map (Tensor[B, C, H, W]): feature map. + token_dict (dict): dict for token information. + + Returns: + out (Tensor[B, N, C]): token features. 
+ """ + idx_token = token_dict['idx_token'] + N = token_dict['token_num'] + H_init, W_init = token_dict['init_grid_size'] + N_init = H_init * W_init + + B, C, H, W = feature_map.shape + device = feature_map.device + + if N_init == N and N == H * W: + # for the initial tokens with grid structure, just reshape + return feature_map.flatten(2).permute(0, 2, 1).contiguous() + + idx_hw = get_grid_index([H_init, W_init], [H, W], + device=device)[None, :].expand(B, -1) + + idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init) + value = feature_map.new_ones(B * N_init) + + # choose the way with fewer flops. + if N_init < N * H * W: + # use sparse matrix multiplication + # Flops: B * N_init * (C+2) + idx_token = idx_token + idx_batch * N + idx_hw = idx_hw + idx_batch * H * W + indices = torch.stack([idx_token, idx_hw], dim=0).reshape(2, -1) + + # sparse mm do not support gradient for sparse matrix + value = value.detach().to(torch.float32) + # build a sparse matrix with shape [B*N, B*H*W] + A = torch.sparse_coo_tensor(indices, value, (B * N, B * H * W)) + # normalize the matrix + all_weight = A @ torch.ones( + [B * H * W, 1], device=device, dtype=torch.float32) + 1e-6 + value = value / all_weight[idx_token.reshape(-1), 0] + + A = torch.sparse_coo_tensor(indices, value, (B * N, B * H * W)) + # out: [B*N, C] + out = A @ feature_map. \ + permute(0, 2, 3, 1).contiguous().reshape(B * H * W, C).float() + else: + # use dense matrix multiplication + # Flops: B * N * H * W * (C+2) + indices = torch.stack([idx_batch, idx_token, idx_hw], + dim=0).reshape(3, -1) + value = value.detach() # To reduce the training time, we detach here. + A = torch.sparse_coo_tensor(indices, value, (B, N, H * W)).to_dense() + # normalize the matrix + A = A / (A.sum(dim=-1, keepdim=True) + 1e-6) + + out = A @ feature_map.permute(0, 2, 3, 1).reshape(B, H * W, + C).contiguous() + + out = out.type(feature_map.dtype) + out = out.reshape(B, N, C) + return out + + +def token_interp(target_dict, source_dict): + """Transform token features between different distribution. + + Note: + B: batch size + N: token number + C: channel number + + Args: + target_dict (dict): dict for target token information + source_dict (dict): dict for source token information. + + Returns: + x_out (Tensor[B, N, C]): token features. + """ + + x_s = source_dict['x'] + idx_token_s = source_dict['idx_token'] + idx_token_t = target_dict['idx_token'] + T = target_dict['token_num'] + B, S, C = x_s.shape + N_init = idx_token_s.shape[1] + + weight = target_dict['agg_weight'] if 'agg_weight' in target_dict.keys( + ) else None + if weight is None: + weight = x_s.new_ones(B, N_init, 1) + weight = weight.reshape(-1) + + # choose the way with fewer flops. 
+ if N_init < T * S: + # use sparse matrix multiplication + # Flops: B * N_init * (C+2) + idx_token_t = idx_token_t + torch.arange( + B, device=x_s.device)[:, None] * T + idx_token_s = idx_token_s + torch.arange( + B, device=x_s.device)[:, None] * S + coor = torch.stack([idx_token_t, idx_token_s], + dim=0).reshape(2, B * N_init) + + # torch.sparse does not support grad for sparse matrix + weight = weight.float().detach().to(torch.float32) + # build a matrix with shape [B*T, B*S] + A = torch.sparse.FloatTensor(coor, weight, torch.Size([B * T, B * S])) + # normalize the matrix + all_weight = A.type(torch.float32) @ x_s.new_ones(B * S, 1).type( + torch.float32) + 1e-6 + weight = weight / all_weight[(idx_token_t).reshape(-1), 0] + A = torch.sparse.FloatTensor(coor, weight, torch.Size([B * T, B * S])) + # sparse matmul + x_out = A.type(torch.float32) @ x_s.reshape(B * S, C).type( + torch.float32) + else: + # use dense matrix multiplication + # Flops: B * T * S * (C+2) + idx_batch = torch.arange( + B, device=x_s.device)[:, None].expand(B, N_init) + coor = torch.stack([idx_batch, idx_token_t, idx_token_s], + dim=0).reshape(3, B * N_init) + weight = weight.detach() # detach to reduce training time + # build a matrix with shape [B, T, S] + A = torch.sparse.FloatTensor(coor, weight, torch.Size([B, T, + S])).to_dense() + # normalize the matrix + A = A / (A.sum(dim=-1, keepdim=True) + 1e-6) + # dense matmul + x_out = A @ x_s + + x_out = x_out.reshape(B, T, C).type(x_s.dtype) + return x_out + + +def cluster_dpc_knn(token_dict, cluster_num, k=5, token_mask=None): + """Cluster tokens with DPC-KNN algorithm. + + Note: + B: batch size + N: token number + C: channel number + + Args: + token_dict (dict): dict for token information + cluster_num (int): cluster number + k (int): number of the nearest neighbor used for local density. + token_mask (Tensor[B, N]): mask indicating which token is the + padded empty token. Non-zero value means the token is meaningful, + zero value means the token is an empty token. If set to None, all + tokens are regarded as meaningful. + + Return: + idx_cluster (Tensor[B, N]): cluster index of each token. + cluster_num (int): actual cluster number. In this function, it equals + to the input cluster number. + """ + + with torch.no_grad(): + x = token_dict['x'] + B, N, C = x.shape + + dist_matrix = torch.cdist(x, x) / (C**0.5) + + if token_mask is not None: + token_mask = token_mask > 0 + # in order to not affect the local density, the + # distance between empty tokens and any other + # tokens should be the maximal distance. + dist_matrix = \ + dist_matrix * token_mask[:, None, :] +\ + (dist_matrix.max() + 1) * (~token_mask[:, None, :]) + + # get local density + dist_nearest, index_nearest = torch.topk( + dist_matrix, k=k, dim=-1, largest=False) + + density = (-(dist_nearest**2).mean(dim=-1)).exp() + # add a little noise to ensure no tokens have the same density. 
+ density = density + torch.rand( + density.shape, device=density.device, dtype=density.dtype) * 1e-6 + + if token_mask is not None: + # the density of empty token should be 0 + density = density * token_mask + + # get distance indicator + mask = density[:, None, :] > density[:, :, None] + mask = mask.type(x.dtype) + dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None] + dist, index_parent = (dist_matrix * mask + dist_max * + (1 - mask)).min(dim=-1) + + # select clustering center according to score + score = dist * density + _, index_down = torch.topk(score, k=cluster_num, dim=-1) + + # assign tokens to the nearest center + dist_matrix = index_points(dist_matrix, index_down) + + idx_cluster = dist_matrix.argmin(dim=1) + + # make sure cluster center merge to itself + idx_batch = torch.arange( + B, device=x.device)[:, None].expand(B, cluster_num) + idx_tmp = torch.arange( + cluster_num, device=x.device)[None, :].expand(B, cluster_num) + idx_cluster[idx_batch.reshape(-1), + index_down.reshape(-1)] = idx_tmp.reshape(-1) + + return idx_cluster, cluster_num + + +def merge_tokens(token_dict, idx_cluster, cluster_num, token_weight=None): + """Merge tokens in the same cluster to a single cluster. Implemented by + torch.index_add(). Flops: B*N*(C+2) + + Note: + B: batch size + N: token number + C: channel number + + Args: + token_dict (dict): dict for input token information + idx_cluster (Tensor[B, N]): cluster index of each token. + cluster_num (int): cluster number + token_weight (Tensor[B, N, 1]): weight for each token. + + Return: + out_dict (dict): dict for output token information + """ + + x = token_dict['x'] + idx_token = token_dict['idx_token'] + agg_weight = token_dict['agg_weight'] + + B, N, C = x.shape + if token_weight is None: + token_weight = x.new_ones(B, N, 1) + + idx_batch = torch.arange(B, device=x.device)[:, None] + idx = idx_cluster + idx_batch * cluster_num + + all_weight = token_weight.new_zeros(B * cluster_num, 1) + all_weight.index_add_( + dim=0, index=idx.reshape(B * N), source=token_weight.reshape(B * N, 1)) + all_weight = all_weight + 1e-6 + norm_weight = token_weight / all_weight[idx] + + # average token features + x_merged = x.new_zeros(B * cluster_num, C) + source = x * norm_weight + x_merged.index_add_( + dim=0, + index=idx.reshape(B * N), + source=source.reshape(B * N, C).type(x.dtype)) + x_merged = x_merged.reshape(B, cluster_num, C) + + idx_token_new = index_points(idx_cluster[..., None], idx_token).squeeze(-1) + weight_t = index_points(norm_weight, idx_token) + agg_weight_new = agg_weight * weight_t + agg_weight_new / agg_weight_new.max(dim=1, keepdim=True)[0] + + out_dict = {} + out_dict['x'] = x_merged + out_dict['token_num'] = cluster_num + out_dict['map_size'] = token_dict['map_size'] + out_dict['init_grid_size'] = token_dict['init_grid_size'] + out_dict['idx_token'] = idx_token_new + out_dict['agg_weight'] = agg_weight_new + return out_dict + + +class MLP(nn.Module): + """FFN with Depthwise Conv of TCFormer. + + Args: + in_features (int): The feature dimension. + hidden_features (int, optional): The hidden dimension of FFNs. + Defaults: The same as in_features. + out_features (int, optional): The output feature dimension. + Defaults: The same as in_features. + act_layer (nn.Module, optional): The activation config for FFNs. + Default: nn.GELU. + drop (float, optional): drop out rate. Default: 0. 
+ """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def init_weights(self): + """init weights.""" + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + """Depthwise Conv for regular grid-based tokens. + + Args: + dim (int): The feature dimension. + """ + + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.transpose(1, 2).view(B, C, H, W) + x = self.dwconv(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class TCFormerRegularAttention(nn.Module): + """Spatial Reduction Attention for regular grid-based tokens. + + Args: + dim (int): The feature dimension of tokens, + num_heads (int): Parallel attention heads. + qkv_bias (bool): enable bias for qkv if True. Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after attention process. + Default: 0.0. + sr_ratio (int): The ratio of spatial reduction of Spatial Reduction + Attention. Default: 1. + use_sr_conv (bool): If True, use a conv layer for spatial reduction. + If False, use a pooling process for spatial reduction. Defaults: + True. + """ + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + sr_ratio=1, + use_sr_conv=True, + ): + super().__init__() + assert dim % num_heads == 0, \ + f'dim {dim} should be divided by num_heads {num_heads}.' + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + self.use_sr_conv = use_sr_conv + if sr_ratio > 1 and self.use_sr_conv: + self.sr = nn.Conv2d( + dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) 
+ elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + kv = x.permute(0, 2, 1).reshape(B, C, H, W) + if self.use_sr_conv: + kv = self.sr(kv).reshape(B, C, -1).permute(0, 2, + 1).contiguous() + kv = self.norm(kv) + else: + kv = F.avg_pool2d( + kv, kernel_size=self.sr_ratio, stride=self.sr_ratio) + kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous() + else: + kv = x + + kv = self.kv(kv).reshape(B, -1, 2, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, + 4).contiguous() + k, v = kv[0], kv[1] + + attn = (q * self.scale) @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class TCFormerRegularBlock(nn.Module): + """Transformer block for regular grid-based tokens. + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (int): The expansion ratio for the FFNs. + qkv_bias (bool): enable bias for qkv if True. Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop (float): Dropout layers after attention process and in FFN. + Default: 0.0. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + drop_path (int, optional): The drop path rate of transformer block. + Default: 0.0 + act_layer (nn.Module, optional): The activation config for FFNs. + Default: nn.GELU. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Spatial Reduction + Attention. Default: 1. + use_sr_conv (bool): If True, use a conv layer for spatial reduction. + If False, use a pooling process for spatial reduction. Defaults: + True. + """ + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_cfg=dict(type='LN'), + sr_ratio=1, + use_sr_conv=True): + super().__init__() + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + + self.attn = TCFormerRegularAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + use_sr_conv=use_sr_conv) + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path)) + + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) 
+ elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, x, H, W): + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) + return x + + +class TokenConv(nn.Conv2d): + """Conv layer for dynamic tokens. + + A skip link is added between the input and output tokens to reserve detail + tokens. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + groups = kwargs['groups'] if 'groups' in kwargs.keys() else 1 + self.skip = nn.Conv1d( + in_channels=kwargs['in_channels'], + out_channels=kwargs['out_channels'], + kernel_size=1, + bias=False, + groups=groups) + + def forward(self, token_dict): + x = token_dict['x'] + x = self.skip(x.permute(0, 2, 1)).permute(0, 2, 1) + x_map = token2map(token_dict) + x_map = super().forward(x_map) + x = x + map2token(x_map, token_dict) + return x + + +class TCMLP(nn.Module): + """FFN with Depthwise Conv for dynamic tokens. + + Args: + in_features (int): The feature dimension. + hidden_features (int, optional): The hidden dimension of FFNs. + Defaults: The same as in_features. + out_features (int, optional): The output feature dimension. + Defaults: The same as in_features. + act_layer (nn.Module, optional): The activation config for FFNs. + Default: nn.GELU. + drop (float, optional): drop out rate. Default: 0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = TokenConv( + in_channels=hidden_features, + out_channels=hidden_features, + kernel_size=3, + padding=1, + stride=1, + bias=True, + groups=hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def init_weights(self): + """init weights.""" + for m in self.modules(): + if isinstance(m, nn.Linear): + trunc_normal_init(m, std=.02, bias=0.) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def forward(self, token_dict): + token_dict['x'] = self.fc1(token_dict['x']) + x = self.dwconv(token_dict) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class TCFormerDynamicAttention(TCFormerRegularAttention): + """Spatial Reduction Attention for dynamic tokens.""" + + def forward(self, q_dict, kv_dict): + """Attention process for dynamic tokens. + Dynamic tokens are represented by a dict with the following keys: + x (torch.Tensor[B, N, C]): token features. + token_num(int): token number. + map_size(list[int] or tuple[int]): feature map resolution in + format [H, W]. + init_grid_size(list[int] or tuple[int]): initial grid resolution + in format [H_init, W_init]. + idx_token(torch.LongTensor[B, N_init]): indicates which token + the initial grid belongs to. + agg_weight(torch.LongTensor[B, N_init] or None): weight for + aggregation. 
Indicates the weight of each token in its + cluster. If set to None, uniform weight is used. + + Note: + B: batch size + N: token number + C: channel number + Ns: sampled point number + [H_init, W_init]: shape of initial grid + [H, W]: shape of feature map + N_init: numbers of initial token + + Args: + q_dict (dict): dict for query token information + kv_dict (dict): dict for key and value token information + + Return: + x (torch.Tensor[B, N, C]): output token features. + """ + + q = q_dict['x'] + kv = kv_dict['x'] + B, Nq, C = q.shape + Nkv = kv.shape[1] + conf_kv = kv_dict['token_score'] if 'token_score' in kv_dict.keys( + ) else kv.new_zeros(B, Nkv, 1) + + q = self.q(q).reshape(B, Nq, self.num_heads, + C // self.num_heads).permute(0, 2, 1, + 3).contiguous() + + if self.sr_ratio > 1: + tmp = torch.cat([kv, conf_kv], dim=-1) + tmp_dict = kv_dict.copy() + tmp_dict['x'] = tmp + tmp_dict['map_size'] = q_dict['map_size'] + tmp = token2map(tmp_dict) + + kv = tmp[:, :C] + conf_kv = tmp[:, C:] + + if self.use_sr_conv: + kv = self.sr(kv) + _, _, h, w = kv.shape + kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous() + kv = self.norm(kv) + else: + kv = F.avg_pool2d( + kv, kernel_size=self.sr_ratio, stride=self.sr_ratio) + kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous() + + conf_kv = F.avg_pool2d( + conf_kv, kernel_size=self.sr_ratio, stride=self.sr_ratio) + conf_kv = conf_kv.reshape(B, 1, -1).permute(0, 2, 1).contiguous() + + kv = self.kv(kv).reshape(B, -1, 2, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, + 4).contiguous() + k, v = kv[0], kv[1] + + attn = (q * self.scale) @ k.transpose(-2, -1) + + conf_kv = conf_kv.squeeze(-1)[:, None, None, :] + attn = attn + conf_kv + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Nq, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +# Transformer block for dynamic tokens +class TCFormerDynamicBlock(TCFormerRegularBlock): + """Transformer block for dynamic tokens. + + Args: + dim (int): The feature dimension. + num_heads (int): Parallel attention heads. + mlp_ratio (int): The expansion ratio for the FFNs. + qkv_bias (bool): enable bias for qkv if True. Default: False. + qk_scale (float | None, optional): Override default qk scale of + head_dim ** -0.5 if set. Default: None. + drop (float): Dropout layers after attention process and in FFN. + Default: 0.0. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + drop_path (int, optional): The drop path rate of transformer block. + Default: 0.0 + act_layer (nn.Module, optional): The activation config for FFNs. + Default: nn.GELU. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + sr_ratio (int): The ratio of spatial reduction of Spatial Reduction + Attention. Default: 1. + use_sr_conv (bool): If True, use a conv layer for spatial reduction. + If False, use a pooling process for spatial reduction. Defaults: + True. 
+ """ + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_cfg=dict(type='LN'), + sr_ratio=1, + use_sr_conv=True): + super(TCFormerRegularBlock, self).__init__() + self.norm1 = build_norm_layer(norm_cfg, dim)[1] + + self.attn = TCFormerDynamicAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + use_sr_conv=use_sr_conv) + self.drop_path = build_dropout( + dict(type='DropPath', drop_prob=drop_path)) + + self.norm2 = build_norm_layer(norm_cfg, dim)[1] + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = TCMLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, inputs): + """Forward function. + + Args: + inputs (dict or tuple[dict] or list[dict]): input dynamic + token information. If a single dict is provided, it's + regraded as query and key, value. If a tuple or list + of dict is provided, the first one is regarded as key + and the second one is regarded as key, value. + + Return: + q_dict (dict): dict for output token information + """ + if isinstance(inputs, tuple) or isinstance(inputs, list): + q_dict, kv_dict = inputs + else: + q_dict, kv_dict = inputs, None + + x = q_dict['x'] + # norm1 + q_dict['x'] = self.norm1(q_dict['x']) + if kv_dict is None: + kv_dict = q_dict + else: + kv_dict['x'] = self.norm1(kv_dict['x']) + + # attn + x = x + self.drop_path(self.attn(q_dict, kv_dict)) + + # mlp + q_dict['x'] = self.norm2(x) + x = x + self.drop_path(self.mlp(q_dict)) + q_dict['x'] = x + + return q_dict diff --git a/model-index.yml b/model-index.yml index ab193e060b..aa7691be16 100644 --- a/model-index.yml +++ b/model-index.yml @@ -141,6 +141,7 @@ Import: - configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/hrnet_coco-wholebody.yml - configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/hrnet_dark_coco-wholebody.yml - configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/resnet_coco-wholebody.yml +- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml - configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/vipnas_coco-wholebody.yml - configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/vipnas_dark_coco-wholebody.yml - configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/halpe/hrnet_dark_halpe.yml diff --git a/tests/test_backbones/test_tcformer.py b/tests/test_backbones/test_tcformer.py new file mode 100644 index 0000000000..848de97a82 --- /dev/null +++ b/tests/test_backbones/test_tcformer.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmpose.models.backbones.tcformer import TCFormer + + +def test_tcformer(): + with pytest.raises(TypeError): + # Pretrained arg must be str or None. 
        TCFormer(pretrained=123)

    # test load pretrained weights
    model = TCFormer(
        embed_dims=[64, 128, 320, 512],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        num_layers=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        pretrained='https://download.openmmlab.com/mmpose/'
        'pretrain_models/tcformer-4e1adbf1_20220421.pth')
    model.init_weights()

    # test init weights from scratch
    model = TCFormer(embed_dims=[32, 32, 32, 32], num_layers=[2, 2, 2, 2])
    model.init_weights()

    # Test normal inference
    model = TCFormer()
    temp = torch.randn((1, 3, 256, 192))
    outs = model(temp)
    assert len(outs) == 4
    assert isinstance(outs[0], dict)
    for key in [
            'x', 'token_num', 'map_size', 'init_grid_size', 'idx_token',
            'agg_weight'
    ]:
        assert key in outs[0].keys()

    assert outs[0]['x'].shape == (1, 3072, 64)
    assert outs[1]['x'].shape == (1, 768, 128)
    assert outs[2]['x'].shape == (1, 192, 256)
    assert outs[3]['x'].shape == (1, 48, 512)
    assert outs[3]['idx_token'].shape == (1, 3072)
    assert outs[3]['token_num'] == 48
    assert outs[3]['map_size'] == [8, 6]
    assert outs[3]['init_grid_size'] == [64, 48]

    # Test abnormal inference size
    temp = torch.randn((1, 3, 193, 255))
    outs = model(temp)
    assert outs[0]['x'].shape == (1, 3136, 64)
    assert outs[1]['x'].shape == (1, 784, 128)
    assert outs[2]['x'].shape == (1, 196, 256)
    assert outs[3]['x'].shape == (1, 49, 512)

    # Test output feature map
    model = TCFormer(return_map=True)
    temp = torch.randn((1, 3, 256, 192))
    outs = model(temp)
    assert len(outs) == 4
    assert outs[0].shape == (1, 64, 64, 48)
    assert outs[1].shape == (1, 128, 32, 24)
    assert outs[2].shape == (1, 256, 16, 12)
    assert outs[3].shape == (1, 512, 8, 6)
diff --git a/tests/test_necks/test_tcformer_mta_neck.py b/tests/test_necks/test_tcformer_mta_neck.py
new file mode 100644
index 0000000000..3114077564
--- /dev/null
+++ b/tests/test_necks/test_tcformer_mta_neck.py
@@ -0,0 +1,150 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.models.backbones.tcformer import TCFormer
from mmpose.models.necks.tcformer_mta_neck import MTA


def test_mta():
    in_channels = [8, 16, 32, 64]
    out_channels = 8

    # end_level=-1 is equal to end_level=3
    MTA(in_channels=in_channels,
        out_channels=out_channels,
        start_level=0,
        end_level=-1,
        num_outs=5)
    MTA(in_channels=in_channels,
        out_channels=out_channels,
        start_level=0,
        end_level=3,
        num_outs=5)

    # `num_outs` is not equal to end_level - start_level + 1
    with pytest.raises(AssertionError):
        MTA(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            end_level=2,
            num_outs=3)

    # `num_outs` is smaller than len(in_channels) - start_level
    with pytest.raises(AssertionError):
        MTA(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            num_outs=2)

    # `end_level` is larger than len(in_channels) - 1
    with pytest.raises(AssertionError):
        MTA(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            end_level=4,
            num_outs=2)

    # `num_outs` is smaller than end_level - start_level + 1
    with pytest.raises(AssertionError):
        MTA(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            end_level=3,
            num_outs=1)

    # Invalid `add_extra_convs` option
    with pytest.raises(AssertionError):
        MTA(in_channels=in_channels,
            out_channels=out_channels,
            start_level=1,
            add_extra_convs='on_xxx',
            num_outs=5)

    backbone = TCFormer(embed_dims=[8, 16, 32, 64])
    temp = torch.randn((1, 3, 256, 192))
    feats = backbone(temp)
    # map size of the first stage (stride 4)
    h, w = 64, 48

    # normal forward
    mta_model = MTA(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs=True,
        num_outs=5)
    assert mta_model.add_extra_convs == 'on_input'
    outs = mta_model(feats)
    assert len(outs) == 5
    for i in range(mta_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == h // (2**i)
        assert outs[i].shape[3] == w // (2**i)

    # Tests for mta with no extra convs (pooling is used instead)
    mta_model = MTA(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=1,
        add_extra_convs=False,
        num_outs=5)
    outs = mta_model(feats)
    assert len(outs) == mta_model.num_outs
    assert not mta_model.add_extra_convs
    for i in range(mta_model.num_outs):
        assert outs[i].shape[1] == out_channels
        # with start_level=1 the first output is at 1/8 of the input scale
        assert outs[i].shape[2] == h // (2**(i + 1))

    # Tests for mta with lateral bns
    mta_model = MTA(
        in_channels=in_channels,
        out_channels=out_channels,
        start_level=1,
        add_extra_convs=True,
        no_norm_on_lateral=False,
        norm_cfg=dict(type='BN', requires_grad=True),
        num_outs=5)
    outs = mta_model(feats)
    assert len(outs) == mta_model.num_outs
    assert mta_model.add_extra_convs == 'on_input'
    for i in range(mta_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == h // (2**(i + 1))

    bn_exist = False
    for m in mta_model.modules():
        if isinstance(m, _BatchNorm):
            bn_exist = True
    assert bn_exist

    # Extra convs source is 'inputs'
    mta_model = MTA(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_input',
        start_level=1,
        num_outs=5)
    assert mta_model.add_extra_convs == 'on_input'
    outs = mta_model(feats)
    assert len(outs) == mta_model.num_outs
    for i in range(mta_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == h // (2**(i + 1))

    # Extra convs source is 'outputs'
    mta_model = MTA(
        in_channels=in_channels,
        out_channels=out_channels,
        add_extra_convs='on_output',
        start_level=1,
        num_outs=5)
    assert mta_model.add_extra_convs == 'on_output'
    outs = mta_model(feats)
    assert len(outs) == mta_model.num_outs
    for i in range(mta_model.num_outs):
        assert outs[i].shape[1] == out_channels
        assert outs[i].shape[2] == h // (2**(i + 1))
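
For readers skimming the diff, the following is a minimal, hypothetical usage sketch; it is not part of the patch and only assumes the interfaces exercised by the tests above: `TCFormer` returns one token dict per stage (with keys such as `x`, `token_num` and `map_size`), and the `MTA` neck turns that list of dicts into an FPN-style list of feature maps.

```python
# Minimal sketch (not part of this PR), mirroring the configuration used in
# test_mta above. The expected output sizes follow the shape checks in that
# test and are stated here as assumptions, not guarantees.
import torch

from mmpose.models.backbones.tcformer import TCFormer
from mmpose.models.necks.tcformer_mta_neck import MTA

backbone = TCFormer(embed_dims=[8, 16, 32, 64])
neck = MTA(
    in_channels=[8, 16, 32, 64],
    out_channels=8,
    add_extra_convs='on_input',
    num_outs=5)

img = torch.randn(1, 3, 256, 192)
token_dicts = backbone(img)    # list of 4 token dicts, one per stage
feat_maps = neck(token_dicts)  # 5 maps, roughly 64x48 down to 4x3
print([tuple(f.shape) for f in feat_maps])
```

Passing token dicts rather than plain tensors is the point of the dynamic blocks above: the clustered tokens carry `map_size` alongside the features, so `token2map` can recover a regular feature map whenever a spatial operation (spatial-reduction attention in the backbone, lateral convs in the neck) needs one.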