diff --git a/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.md b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.md
new file mode 100644
index 0000000000..6ab872452c
--- /dev/null
+++ b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.md
@@ -0,0 +1,38 @@
+
+
+
+TCFormer (CVPR'2022)
+
+```bibtex
+@inproceedings{zeng2022not,
+ title={Not All Tokens Are Equal: Human-centric Visual Analysis via Token Clustering Transformer},
+ author={Zeng, Wang and Jin, Sheng and Liu, Wentao and Qian, Chen and Luo, Ping and Ouyang, Wanli and Wang, Xiaogang},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={11101--11111},
+ year={2022}
+}
+```
+
+COCO-WholeBody (ECCV'2020)
+
+```bibtex
+@inproceedings{jin2020whole,
+ title={Whole-Body Human Pose Estimation in the Wild},
+ author={Jin, Sheng and Xu, Lumin and Xu, Jin and Wang, Can and Liu, Wentao and Qian, Chen and Ouyang, Wanli and Luo, Ping},
+ booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
+ year={2020}
+}
+```
+
+
+
+Results on COCO-WholeBody v1.0 val with detector having human AP of 56.4 on COCO val2017 dataset
+
+| Arch | Input Size | Body AP | Body AR | Foot AP | Foot AR | Face AP | Face AR | Hand AP | Hand AR | Whole AP | Whole AR | ckpt | log |
+| :-------------------------------------- | :--------: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :------: | :------: | :--------------------------------------: | :-------------------------------------: |
+| [tcformer](/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py) | 256x192 | 0.691 | 0.769 | 0.690 | 0.809 | 0.650 | 0.747 | 0.534 | 0.647 | 0.574 | 0.678 | [ckpt](https://download.openmmlab.com/mmpose/top_down/tcformer/tcformer_coco-wholebody_256x192-a0720efa_20220627.pth) | [log](https://download.openmmlab.com/mmpose/top_down/tcformer/tcformer_coco-wholebody_256x192_20220627.log.json) |
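+
+The released checkpoint can be used with the standard MMPose top-down
+inference API. The snippet below is only a minimal sketch: the image path and
+the whole-image bounding box are placeholders, and in practice the person
+boxes would come from a detector such as the one used for the table above.
+
+```python
+from mmpose.apis import inference_top_down_pose_model, init_pose_model
+from mmpose.datasets import DatasetInfo
+
+config_file = ('configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/'
+               'coco-wholebody/tcformer_coco_wholebody_256x192.py')
+checkpoint_file = ('https://download.openmmlab.com/mmpose/top_down/tcformer/'
+                   'tcformer_coco-wholebody_256x192-a0720efa_20220627.pth')
+
+model = init_pose_model(config_file, checkpoint_file, device='cpu')
+dataset_info = DatasetInfo(model.cfg.data['test']['dataset_info'])
+
+# placeholder input: one whole-image box in xywh format; real person boxes
+# would come from a human detector
+person_results = [{'bbox': [0, 0, 640, 480]}]
+pose_results, _ = inference_top_down_pose_model(
+    model,
+    'path/to/image.jpg',  # placeholder image path
+    person_results,
+    format='xywh',
+    dataset_info=dataset_info)
+```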
diff --git a/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml
new file mode 100644
index 0000000000..8504e196d1
--- /dev/null
+++ b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml
@@ -0,0 +1,30 @@
+Collections:
+- Name: TCFormer
+ Paper:
+ Title: 'Not All Tokens Are Equal: Human-centric Visual Analysis via Token Clustering
+ Transformer'
+ URL: https://openaccess.thecvf.com/content/CVPR2022/html/Zeng_Not_All_Tokens_Are_Equal_Human-Centric_Visual_Analysis_via_Token_CVPR_2022_paper.html
+ README: https://github.com/open-mmlab/mmpose/blob/master/docs/en/papers/backbones/tcformer.md
+Models:
+- Config: configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py
+ In Collection: TCFormer
+ Metadata:
+ Architecture:
+ - TCFormer
+ Training Data: COCO-WholeBody
+ Name: topdown_heatmap_tcformer_coco_wholebody_256x192
+ Results:
+ - Dataset: COCO-WholeBody
+ Metrics:
+ Body AP: 0.691
+ Body AR: 0.769
+ Face AP: 0.65
+ Face AR: 0.747
+ Foot AP: 0.69
+ Foot AR: 0.809
+ Hand AP: 0.534
+ Hand AR: 0.647
+ Whole AP: 0.574
+ Whole AR: 0.678
+ Task: Wholebody 2D Keypoint
+ Weights: https://download.openmmlab.com/mmpose/top_down/tcformer/tcformer_coco-wholebody_256x192-a0720efa_20220627.pth
diff --git a/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py
new file mode 100644
index 0000000000..73d7898686
--- /dev/null
+++ b/configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco_wholebody_256x192.py
@@ -0,0 +1,171 @@
+_base_ = ['../../../../_base_/datasets/coco_wholebody.py']
+log_level = 'INFO'
+load_from = None
+resume_from = None
+dist_params = dict(backend='nccl')
+workflow = [('train', 1)]
+checkpoint_config = dict(interval=10)
+evaluation = dict(interval=10, metric='mAP', save_best='AP')
+
+optimizer = dict(
+ type='AdamW',
+ lr=5e-4,
+ betas=(0.9, 0.999),
+ weight_decay=0.01,
+)
+
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[170, 200])
+total_epochs = 210
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+
+channel_cfg = dict(
+ num_output_channels=133,
+ dataset_joints=133,
+ dataset_channel=[
+ list(range(133)),
+ ],
+ inference_channel=list(range(133)))
+
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='TopDown',
+ pretrained='https://download.openmmlab.com/mmpose/'
+ 'pretrain_models/tcformer-4e1adbf1_20220421.pth',
+ backbone=dict(
+ type='TCFormer',
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ num_layers=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ drop_path_rate=0.1),
+ neck=dict(
+ type='MTA',
+ in_channels=[64, 128, 320, 512],
+ out_channels=256,
+ start_level=0,
+ num_heads=[4, 4, 4, 4],
+ mlp_ratios=[4, 4, 4, 4],
+ num_outs=4,
+ use_sr_conv=False,
+ ),
+ keypoint_head=dict(
+ type='TopdownHeatmapSimpleHead',
+ in_channels=256,
+ out_channels=channel_cfg['num_output_channels'],
+ num_deconv_layers=0,
+ extra=dict(final_conv_kernel=1, ),
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=True,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[192, 256],
+ heatmap_size=[48, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'],
+ soft_nms=False,
+ nms_thr=1.0,
+ oks_thr=0.9,
+ vis_thr=0.2,
+ use_gt_bbox=False,
+ det_bbox_thr=0.0,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
+ dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
+ dict(type='TopDownRandomFlip', flip_prob=0.5),
+ dict(
+ type='TopDownHalfBodyTransform',
+ num_joints_half_body=8,
+ prob_half_body=0.3),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTarget', sigma=2),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs'
+ ]),
+]
+
+val_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs'
+ ]),
+]
+
+test_pipeline = val_pipeline
+
+data_root = 'data/coco'
+data = dict(
+ samples_per_gpu=64,
+ workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=32),
+ test_dataloader=dict(samples_per_gpu=32),
+ train=dict(
+ type='TopDownCocoWholeBodyDataset',
+ ann_file=f'{data_root}/annotations/coco_wholebody_train_v1.0.json',
+ img_prefix=f'{data_root}/train2017/',
+ data_cfg=data_cfg,
+ pipeline=train_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+ val=dict(
+ type='TopDownCocoWholeBodyDataset',
+ ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+ pipeline=val_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+ test=dict(
+ type='TopDownCocoWholeBodyDataset',
+ ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+ pipeline=val_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+)
diff --git a/docs/en/papers/backbones/tcformer.md b/docs/en/papers/backbones/tcformer.md
new file mode 100644
index 0000000000..26e7714bba
--- /dev/null
+++ b/docs/en/papers/backbones/tcformer.md
@@ -0,0 +1,53 @@
+# Not All Tokens Are Equal: Human-Centric Visual Analysis via Token Clustering Transformer
+
+
+
+
+TCFormer (CVPR'2022)
+
+```bibtex
+@inproceedings{zeng2022not,
+ title={Not All Tokens Are Equal: Human-centric Visual Analysis via Token Clustering Transformer},
+ author={Zeng, Wang and Jin, Sheng and Liu, Wentao and Qian, Chen and Luo, Ping and Ouyang, Wanli and Wang, Xiaogang},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={11101--11111},
+ year={2022}
+}
+```
+
+
+
+## Abstract
+
+
+
+Vision transformers have achieved great successes in
+many computer vision tasks. Most methods generate
+vision tokens by splitting an image into a regular
+and fixed grid and treating each cell as a token.
+However, not all regions are equally important in
+human-centric vision tasks, e.g., the human body
+needs a fine representation with many tokens, while
+the image background can be modeled by a few tokens.
+To address this problem, we propose a novel Vision
+Transformer, called Token Clustering Transformer
+(TCFormer), which merges tokens by progressive
+clustering, where the tokens can be merged from
+different locations with flexible shapes and sizes.
+The tokens in TCFormer can not only focus on important
+areas but also adjust the token shapes to fit the
+semantic concept and adopt a fine resolution for
+regions containing critical details, which is
+beneficial to capturing detailed information.
+Extensive experiments show that TCFormer consistently
+outperforms its counterparts on different challenging
+human-centric tasks and datasets, including whole-body
+pose estimation on COCO-WholeBody and 3D human mesh
+reconstruction on 3DPW. Code is available at
+https://github.com/zengwang430521/TCFormer.git.
+
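+## Usage
+
+As a quick illustration (a usage sketch added for this page, not text from the
+paper), the backbone returns one token dict per stage instead of plain feature
+maps. The structure below mirrors the accompanying unit test in
+`tests/test_backbones/test_tcformer.py`.
+
+```python
+import torch
+
+from mmpose.models.backbones import TCFormer
+
+# default settings: 4 stages; tokens are progressively clustered after stage 1
+model = TCFormer()
+model.init_weights()
+
+imgs = torch.randn(1, 3, 256, 192)
+outs = model(imgs)  # one token dict per stage
+
+for stage, token_dict in enumerate(outs, start=1):
+    # 'x' holds the token features; 'idx_token' records which token each
+    # initial grid position has been merged into
+    print(stage, token_dict['x'].shape, token_dict['token_num'])
+```
+
+The MTA neck consumes these token dicts directly; setting `return_map=True`
+instead converts each stage output into a regular feature map via `token2map`.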
diff --git a/mmpose/models/backbones/__init__.py b/mmpose/models/backbones/__init__.py
index 09745d443c..2fc64a8af3 100644
--- a/mmpose/models/backbones/__init__.py
+++ b/mmpose/models/backbones/__init__.py
@@ -22,6 +22,7 @@
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin import SwinTransformer
+from .tcformer import TCFormer
from .tcn import TCN
from .v2v_net import V2VNet
from .vgg import VGG
@@ -34,5 +35,5 @@
'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 'ShuffleNetV2', 'CPM', 'RSN',
'MSPN', 'ResNeSt', 'VGG', 'TCN', 'ViPNAS_ResNet', 'ViPNAS_MobileNetV3',
'LiteHRNet', 'V2VNet', 'HRFormer', 'PyramidVisionTransformer',
- 'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D'
+ 'PyramidVisionTransformerV2', 'SwinTransformer', 'I3D', 'TCFormer'
]
diff --git a/mmpose/models/backbones/tcformer.py b/mmpose/models/backbones/tcformer.py
new file mode 100644
index 0000000000..99cd44d266
--- /dev/null
+++ b/mmpose/models/backbones/tcformer.py
@@ -0,0 +1,284 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import (build_norm_layer, constant_init, normal_init,
+ trunc_normal_init)
+from mmcv.runner import _load_checkpoint, load_state_dict
+
+from ...utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import (PatchEmbed, TCFormerDynamicBlock, TCFormerRegularBlock,
+ TokenConv, cluster_dpc_knn, merge_tokens,
+ tcformer_convert, token2map)
+
+
+class CTM(nn.Module):
+ """Clustering-based Token Merging module in TCFormer.
+
+ Args:
+ sample_ratio (float): The sample ratio of tokens.
+ embed_dim (int): Input token feature dimension.
+ dim_out (int): Output token feature dimension.
+        k (int): number of nearest neighbors used in the DPC-KNN algorithm.
+ """
+
+ def __init__(self, sample_ratio, embed_dim, dim_out, k=5):
+ super().__init__()
+ self.sample_ratio = sample_ratio
+ self.dim_out = dim_out
+ self.conv = TokenConv(
+ in_channels=embed_dim,
+ out_channels=dim_out,
+ kernel_size=3,
+ stride=2,
+ padding=1)
+ self.norm = nn.LayerNorm(self.dim_out)
+ self.score = nn.Linear(self.dim_out, 1)
+ self.k = k
+
+ def forward(self, token_dict):
+ token_dict = token_dict.copy()
+ x = self.conv(token_dict)
+ x = self.norm(x)
+ token_score = self.score(x)
+ token_weight = token_score.exp()
+
+ token_dict['x'] = x
+ B, N, C = x.shape
+ token_dict['token_score'] = token_score
+
+ cluster_num = max(math.ceil(N * self.sample_ratio), 1)
+ idx_cluster, cluster_num = cluster_dpc_knn(token_dict, cluster_num,
+ self.k)
+ down_dict = merge_tokens(token_dict, idx_cluster, cluster_num,
+ token_weight)
+
+ H, W = token_dict['map_size']
+ H = math.floor((H - 1) / 2 + 1)
+ W = math.floor((W - 1) / 2 + 1)
+ down_dict['map_size'] = [H, W]
+
+ return down_dict, token_dict
+
+
+@BACKBONES.register_module()
+class TCFormer(nn.Module):
+ """Token Clustering Transformer (TCFormer)
+
+ Implementation of `Not All Tokens Are Equal: Human-centric Visual
+ Analysis via Token Clustering Transformer
+ `
+
+ Args:
+ in_channels (int): Number of input channels. Default: 3.
+ embed_dims (list[int]): Embedding dimension. Default:
+ [64, 128, 256, 512].
+ num_heads (Sequence[int]): The attention heads of each transformer
+ encode layer. Default: [1, 2, 5, 8].
+ mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
+ embedding dim of each transformer block.
+ qkv_bias (bool): Enable bias for qkv if True. Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float): Probability of an element to be zeroed.
+ Default 0.0.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default 0.0.
+ drop_path_rate (float): stochastic depth rate. Default 0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN', eps=1e-6).
+ num_layers (Sequence[int]): The layer number of each transformer encode
+ layer. Default: [3, 4, 6, 3].
+ sr_ratios (Sequence[int]): The spatial reduction rate of each
+ transformer block. Default: [8, 4, 2, 1].
+ num_stages (int): The num of stages. Default: 4.
+ pretrained (str, optional): model pretrained path. Default: None.
+ k (int): number of the nearest neighbor used for local density.
+ sample_ratios (list[float]): The sample ratios of CTM modules.
+ Default: [0.25, 0.25, 0.25]
+ return_map (bool): If True, transfer dynamic tokens to feature map at
+ last. Default: False
+ convert_weights (bool): The flag indicates whether the
+ pre-trained model is from the original repo. We may need
+ to convert some keys to make it compatible.
+ Default: True.
+ """
+
+ def __init__(self,
+ in_channels=3,
+ embed_dims=[64, 128, 256, 512],
+ num_heads=[1, 2, 4, 8],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ num_layers=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ num_stages=4,
+ pretrained=None,
+ k=5,
+ sample_ratios=[0.25, 0.25, 0.25],
+ return_map=False,
+ convert_weights=True):
+ super().__init__()
+
+ self.num_layers = num_layers
+ self.num_stages = num_stages
+ self.grid_stride = sr_ratios[0]
+ self.embed_dims = embed_dims
+ self.sr_ratios = sr_ratios
+ self.mlp_ratios = mlp_ratios
+ self.sample_ratios = sample_ratios
+ self.return_map = return_map
+ self.convert_weights = convert_weights
+
+ # stochastic depth decay rule
+ dpr = [
+ x.item()
+ for x in torch.linspace(0, drop_path_rate, sum(num_layers))
+ ]
+ cur = 0
+
+ # In stage 1, use the standard transformer blocks
+ for i in range(1):
+ patch_embed = PatchEmbed(
+ in_channels=in_channels if i == 0 else embed_dims[i - 1],
+ embed_dims=embed_dims[i],
+ kernel_size=7,
+ stride=4,
+ padding=3,
+ bias=True,
+ norm_cfg=dict(type='LN', eps=1e-6))
+
+ block = nn.ModuleList([
+ TCFormerRegularBlock(
+ dim=embed_dims[i],
+ num_heads=num_heads[i],
+ mlp_ratio=mlp_ratios[i],
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[cur + j],
+ norm_cfg=norm_cfg,
+ sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
+ ])
+ norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
+
+ cur += num_layers[i]
+
+ setattr(self, f'patch_embed{i + 1}', patch_embed)
+ setattr(self, f'block{i + 1}', block)
+ setattr(self, f'norm{i + 1}', norm)
+
+ # In stage 2~4, use TCFormerDynamicBlock for dynamic tokens
+ for i in range(1, num_stages):
+ ctm = CTM(sample_ratios[i - 1], embed_dims[i - 1], embed_dims[i],
+ k)
+
+ block = nn.ModuleList([
+ TCFormerDynamicBlock(
+ dim=embed_dims[i],
+ num_heads=num_heads[i],
+ mlp_ratio=mlp_ratios[i],
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[cur + j],
+ norm_cfg=norm_cfg,
+ sr_ratio=sr_ratios[i]) for j in range(num_layers[i])
+ ])
+ norm = build_norm_layer(norm_cfg, embed_dims[i])[1]
+ cur += num_layers[i]
+
+ setattr(self, f'ctm{i}', ctm)
+ setattr(self, f'block{i + 1}', block)
+ setattr(self, f'norm{i + 1}', norm)
+
+ self.init_weights(pretrained)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+
+ checkpoint = _load_checkpoint(
+ pretrained, logger=logger, map_location='cpu')
+ logger.warning(f'Load pre-trained model for '
+ f'{self.__class__.__name__} from original repo')
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+
+ if self.convert_weights:
+ # We need to convert pre-trained weights to match this
+ # implementation.
+ state_dict = tcformer_convert(state_dict)
+ load_state_dict(self, state_dict, strict=False, logger=logger)
+
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ constant_init(m, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[
+ 1] * m.out_channels
+ fan_out //= m.groups
+ normal_init(m, 0, math.sqrt(2.0 / fan_out))
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ outs = []
+
+ i = 0
+ patch_embed = getattr(self, f'patch_embed{i + 1}')
+ block = getattr(self, f'block{i + 1}')
+ norm = getattr(self, f'norm{i + 1}')
+ x, (H, W) = patch_embed(x)
+ for blk in block:
+ x = blk(x, H, W)
+ x = norm(x)
+
+ # init token dict
+ B, N, _ = x.shape
+ device = x.device
+ idx_token = torch.arange(N)[None, :].repeat(B, 1).to(device)
+ agg_weight = x.new_ones(B, N, 1)
+ token_dict = {
+ 'x': x,
+ 'token_num': N,
+ 'map_size': [H, W],
+ 'init_grid_size': [H, W],
+ 'idx_token': idx_token,
+ 'agg_weight': agg_weight
+ }
+ outs.append(token_dict.copy())
+
+ # stage 2~4
+ for i in range(1, self.num_stages):
+ ctm = getattr(self, f'ctm{i}')
+ block = getattr(self, f'block{i + 1}')
+ norm = getattr(self, f'norm{i + 1}')
+
+ token_dict = ctm(token_dict) # down sample
+ for j, blk in enumerate(block):
+ token_dict = blk(token_dict)
+
+ token_dict['x'] = norm(token_dict['x'])
+ outs.append(token_dict)
+
+ if self.return_map:
+ outs = [token2map(token_dict) for token_dict in outs]
+
+ return outs
diff --git a/mmpose/models/necks/__init__.py b/mmpose/models/necks/__init__.py
index 7f4160f84d..16381b2659 100644
--- a/mmpose/models/necks/__init__.py
+++ b/mmpose/models/necks/__init__.py
@@ -2,5 +2,6 @@
from .fpn import FPN
from .gap_neck import GlobalAveragePooling
from .posewarper_neck import PoseWarperNeck
+from .tcformer_mta_neck import MTA
-__all__ = ['GlobalAveragePooling', 'PoseWarperNeck', 'FPN']
+__all__ = ['GlobalAveragePooling', 'PoseWarperNeck', 'FPN', 'MTA']
diff --git a/mmpose/models/necks/tcformer_mta_neck.py b/mmpose/models/necks/tcformer_mta_neck.py
new file mode 100644
index 0000000000..6723fb018e
--- /dev/null
+++ b/mmpose/models/necks/tcformer_mta_neck.py
@@ -0,0 +1,224 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, constant_init, normal_init, trunc_normal_init
+from mmcv.runner import BaseModule
+
+from ..builder import NECKS
+from ..utils import TCFormerDynamicBlock, token2map, token_interp
+
+
+@NECKS.register_module()
+class MTA(BaseModule):
+ """Multi-stage Token feature Aggregation (MTA) module in TCFormer.
+
+ Args:
+ in_channels (list[int]): Number of input channels per stage.
+ Default: [64, 128, 256, 512].
+ out_channels (int): Number of output channels (used at each scale).
+ num_outs (int): Number of output scales. Default: 4.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, it is equivalent to `add_extra_convs='on_input'`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_output': The last output feature map after fpn convs.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (dict): Config dict for activation layer in ConvModule.
+ num_heads (Sequence[int]): The attention heads of each transformer
+ block. Default: [2, 2, 2, 2].
+ mlp_ratios (Sequence[int]): The ratio of the mlp hidden dim to the
+ embedding dim of each transformer block.
+ sr_ratios (Sequence[int]): The spatial reduction rate of each
+ transformer block. Default: [8, 4, 2, 1].
+ qkv_bias (bool): Enable bias for qkv if True. Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop_rate (float): Probability of an element to be zeroed.
+ Default 0.0.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default 0.0.
+ drop_path_rate (float): stochastic depth rate. Default 0.
+ transformer_norm_cfg (dict): Config dict for normalization layer
+ in transformer blocks. Default: dict(type='LN').
+ use_sr_conv (bool): If True, use a conv layer for spatial reduction.
+ If False, use a pooling process for spatial reduction. Defaults:
+ False.
+ """
+
+ def __init__(
+ self,
+ in_channels=[64, 128, 256, 512],
+ out_channels=128,
+ num_outs=4,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ num_heads=[2, 2, 2, 2],
+ mlp_ratios=[4, 4, 4, 4],
+ sr_ratios=[8, 4, 2, 1],
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ transformer_norm_cfg=dict(type='LN'),
+ use_sr_conv=False,
+ ):
+ super().__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.act_cfg = act_cfg
+ self.mlp_ratios = mlp_ratios
+
+ if end_level == -1 or end_level == self.num_ins - 1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level is not the last level, no extra level is allowed
+ self.backbone_end_level = end_level + 1
+ assert end_level < self.num_ins
+ assert num_outs == end_level - start_level + 1
+ self.start_level = start_level
+ self.end_level = end_level
+
+ self.lateral_convs = nn.ModuleList()
+ self.merge_blocks = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.lateral_convs.append(l_conv)
+
+ for i in range(self.start_level, self.backbone_end_level - 1):
+ merge_block = TCFormerDynamicBlock(
+ dim=out_channels,
+ num_heads=num_heads[i],
+ mlp_ratio=mlp_ratios[i],
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path_rate,
+ norm_cfg=transformer_norm_cfg,
+ sr_ratio=sr_ratios[i],
+ use_sr_conv=use_sr_conv)
+ self.merge_blocks.append(merge_block)
+
+ # add extra conv layers (e.g., RetinaNet)
+ self.relu_before_extra_convs = relu_before_extra_convs
+
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_output')
+ elif add_extra_convs: # True
+ self.add_extra_convs = 'on_input'
+
+ self.extra_convs = nn.ModuleList()
+ extra_levels = num_outs - (self.end_level + 1 - self.start_level)
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.end_level]
+ else:
+ in_channels = out_channels
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.extra_convs.append(extra_fpn_conv)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ constant_init(m, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ normal_init(m, 0, math.sqrt(2.0 / fan_out))
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build lateral tokens
+ input_dicts = []
+ for i, lateral_conv in enumerate(self.lateral_convs):
+ tmp = inputs[i + self.start_level].copy()
+ tmp['x'] = lateral_conv(tmp['x'].unsqueeze(2).permute(
+ 0, 3, 1, 2)).permute(0, 2, 3, 1).squeeze(2)
+ input_dicts.append(tmp)
+
+ # merge from high level to low level
+ for i in range(len(input_dicts) - 2, -1, -1):
+ input_dicts[i]['x'] = input_dicts[i]['x'] + token_interp(
+ input_dicts[i], input_dicts[i + 1])
+ input_dicts[i] = self.merge_blocks[i](input_dicts[i])
+
+ # transform to feature map
+ outs = [token2map(token_dict) for token_dict in input_dicts]
+
+ # part 2: add extra levels
+ used_backbone_levels = len(outs)
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps
+ else:
+ if self.add_extra_convs == 'on_input':
+ tmp = inputs[self.backbone_end_level - 1]
+ extra_source = token2map(tmp)
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+
+ outs.append(self.extra_convs[0](extra_source))
+ for i in range(1, self.num_outs - used_backbone_levels):
+ if self.relu_before_extra_convs:
+ outs.append(self.extra_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.extra_convs[i](outs[-1]))
+ return outs
diff --git a/mmpose/models/utils/__init__.py b/mmpose/models/utils/__init__.py
index 9be0b9a08c..d2ed972655 100644
--- a/mmpose/models/utils/__init__.py
+++ b/mmpose/models/utils/__init__.py
@@ -1,14 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .ckpt_convert import pvt_convert
+from .ckpt_convert import pvt_convert, tcformer_convert
from .geometry import batch_rodrigues, quat_to_rotmat, rot6d_to_rotmat
from .misc import torch_meshgrid_ij
from .ops import resize
from .realnvp import RealNVP
from .smpl import SMPL
+from .tcformer_utils import (TCFormerDynamicBlock, TCFormerRegularBlock,
+ TokenConv, cluster_dpc_knn, merge_tokens,
+ token2map, token_interp)
from .transformer import PatchEmbed, PatchMerging, nchw_to_nlc, nlc_to_nchw
__all__ = [
'SMPL', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'pvt_convert',
'PatchMerging', 'batch_rodrigues', 'quat_to_rotmat', 'rot6d_to_rotmat',
- 'resize', 'RealNVP', 'torch_meshgrid_ij'
+ 'resize', 'RealNVP', 'torch_meshgrid_ij', 'token2map', 'TokenConv',
+ 'TCFormerRegularBlock', 'TCFormerDynamicBlock', 'cluster_dpc_knn',
+ 'merge_tokens', 'token_interp', 'tcformer_convert'
]
diff --git a/mmpose/models/utils/ckpt_convert.py b/mmpose/models/utils/ckpt_convert.py
index 05f5cdb4a3..f5213937db 100644
--- a/mmpose/models/utils/ckpt_convert.py
+++ b/mmpose/models/utils/ckpt_convert.py
@@ -80,3 +80,15 @@ def pvt_convert(ckpt):
new_ckpt[new_k] = new_v
return new_ckpt
+
+
+def tcformer_convert(ckpt):
+ new_ckpt = OrderedDict()
+    # rename patch embedding projection keys ('.proj.' -> '.projection.')
+ for k, v in ckpt.items():
+ if 'patch_embed' in k:
+ new_k = k.replace('.proj.', '.projection.')
+ else:
+ new_k = k
+ new_ckpt[new_k] = v
+ return new_ckpt
diff --git a/mmpose/models/utils/tcformer_utils.py b/mmpose/models/utils/tcformer_utils.py
new file mode 100644
index 0000000000..ff7e09dba4
--- /dev/null
+++ b/mmpose/models/utils/tcformer_utils.py
@@ -0,0 +1,960 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_norm_layer, trunc_normal_init
+from mmcv.cnn.bricks.transformer import build_dropout
+
+
+def get_grid_index(init_grid_size, map_size, device):
+ """For every initial grid, get its index in the feature map.
+ Note:
+ [H_init, W_init]: shape of initial grid
+ [H, W]: shape of feature map
+ N_init: numbers of initial token
+
+ Args:
+ init_grid_size (list[int] or tuple[int]): initial grid resolution in
+ format [H_init, W_init].
+ map_size (list[int] or tuple[int]): feature map resolution in format
+ [H, W].
+ device: the device of output
+
+ Returns:
+ idx (torch.LongTensor[B, N_init]): index in flattened feature map.
+ """
+ H_init, W_init = init_grid_size
+ H, W = map_size
+ idx = torch.arange(H * W, device=device).reshape(1, 1, H, W)
+ idx = F.interpolate(idx.float(), [H_init, W_init], mode='nearest').long()
+ return idx.flatten()
+
+
+def index_points(points, idx):
+ """Sample features following the index.
+ Note:
+ B: batch size
+ N: point number
+ C: channel number of each point
+ Ns: sampled point number
+
+ Args:
+ points (torch.Tensor[B, N, C]): input points data
+ idx (torch.LongTensor[B, S]): sample index
+
+ Returns:
+ new_points (torch.Tensor[B, Ns, C]):, indexed points data
+ """
+ device = points.device
+ B = points.shape[0]
+ view_shape = list(idx.shape)
+ view_shape[1:] = [1] * (len(view_shape) - 1)
+ repeat_shape = list(idx.shape)
+ repeat_shape[0] = 1
+ batch_indices = torch.arange(
+ B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
+ new_points = points[batch_indices, idx, :]
+ return new_points
+
+
+def token2map(token_dict):
+ """Transform vision tokens to feature map. This function only works when
+ the resolution of the feature map is not higher than the initial grid
+ structure.
+
+ Note:
+ B: batch size
+ C: channel number of each token
+ [H, W]: shape of feature map
+ N_init: numbers of initial token
+
+ Args:
+ token_dict (dict): dict for token information.
+
+ Returns:
+ x_out (Tensor[B, C, H, W]): feature map.
+ """
+
+ x = token_dict['x']
+ H, W = token_dict['map_size']
+ H_init, W_init = token_dict['init_grid_size']
+ idx_token = token_dict['idx_token']
+ B, N, C = x.shape
+ N_init = H_init * W_init
+ device = x.device
+
+ if N_init == N and N == H * W:
+ # for the initial tokens with grid structure, just reshape
+ return x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+
+ # for each initial grid, get the corresponding index in
+ # the flattened feature map.
+ idx_hw = get_grid_index([H_init, W_init], [H, W],
+ device=device)[None, :].expand(B, -1)
+ idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init)
+ value = x.new_ones(B * N_init)
+
+ # choose the way with fewer flops.
+ if N_init < N * H * W:
+ # use sparse matrix multiplication
+ # Flops: B * N_init * (C+2)
+ idx_hw = idx_hw + idx_batch * H * W
+ idx_tokens = idx_token + idx_batch * N
+ coor = torch.stack([idx_hw, idx_tokens], dim=0).reshape(2, B * N_init)
+
+ # torch.sparse do not support gradient for
+ # sparse tensor, so we detach it
+ value = value.detach().to(torch.float32)
+
+ # build a sparse matrix with the shape [B * H * W, B * N]
+ A = torch.sparse.FloatTensor(coor, value,
+ torch.Size([B * H * W, B * N]))
+
+ # normalize the weight for each row
+ all_weight = A @ x.new_ones(B * N, 1).type(torch.float32) + 1e-6
+ value = value / all_weight[idx_hw.reshape(-1), 0]
+
+ # update the matrix with normalize weight
+ A = torch.sparse.FloatTensor(coor, value,
+ torch.Size([B * H * W, B * N]))
+
+ # sparse matrix multiplication
+ x_out = A @ x.reshape(B * N, C).to(torch.float32) # [B*H*W, C]
+
+ else:
+ # use dense matrix multiplication
+ # Flops: B * N * H * W * (C+2)
+ coor = torch.stack([idx_batch, idx_hw, idx_token],
+ dim=0).reshape(3, B * N_init)
+
+ # build a matrix with shape [B, H*W, N]
+ A = torch.sparse.FloatTensor(coor, value, torch.Size([B, H * W,
+ N])).to_dense()
+ # normalize the weight
+ A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
+
+ x_out = A @ x # [B, H*W, C]
+
+ x_out = x_out.type(x.dtype)
+ x_out = x_out.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
+ return x_out
+
+
+def map2token(feature_map, token_dict):
+ """Transform feature map to vision tokens. This function only works when
+ the resolution of the feature map is not higher than the initial grid
+ structure.
+
+ Note:
+ B: batch size
+ C: channel number
+ [H, W]: shape of feature map
+ N_init: numbers of initial token
+
+ Args:
+ feature_map (Tensor[B, C, H, W]): feature map.
+ token_dict (dict): dict for token information.
+
+ Returns:
+ out (Tensor[B, N, C]): token features.
+ """
+ idx_token = token_dict['idx_token']
+ N = token_dict['token_num']
+ H_init, W_init = token_dict['init_grid_size']
+ N_init = H_init * W_init
+
+ B, C, H, W = feature_map.shape
+ device = feature_map.device
+
+ if N_init == N and N == H * W:
+ # for the initial tokens with grid structure, just reshape
+ return feature_map.flatten(2).permute(0, 2, 1).contiguous()
+
+ idx_hw = get_grid_index([H_init, W_init], [H, W],
+ device=device)[None, :].expand(B, -1)
+
+ idx_batch = torch.arange(B, device=device)[:, None].expand(B, N_init)
+ value = feature_map.new_ones(B * N_init)
+
+ # choose the way with fewer flops.
+ if N_init < N * H * W:
+ # use sparse matrix multiplication
+ # Flops: B * N_init * (C+2)
+ idx_token = idx_token + idx_batch * N
+ idx_hw = idx_hw + idx_batch * H * W
+ indices = torch.stack([idx_token, idx_hw], dim=0).reshape(2, -1)
+
+ # sparse mm do not support gradient for sparse matrix
+ value = value.detach().to(torch.float32)
+ # build a sparse matrix with shape [B*N, B*H*W]
+ A = torch.sparse_coo_tensor(indices, value, (B * N, B * H * W))
+ # normalize the matrix
+ all_weight = A @ torch.ones(
+ [B * H * W, 1], device=device, dtype=torch.float32) + 1e-6
+ value = value / all_weight[idx_token.reshape(-1), 0]
+
+ A = torch.sparse_coo_tensor(indices, value, (B * N, B * H * W))
+ # out: [B*N, C]
+ out = A @ feature_map. \
+ permute(0, 2, 3, 1).contiguous().reshape(B * H * W, C).float()
+ else:
+ # use dense matrix multiplication
+ # Flops: B * N * H * W * (C+2)
+ indices = torch.stack([idx_batch, idx_token, idx_hw],
+ dim=0).reshape(3, -1)
+ value = value.detach() # To reduce the training time, we detach here.
+ A = torch.sparse_coo_tensor(indices, value, (B, N, H * W)).to_dense()
+ # normalize the matrix
+ A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
+
+ out = A @ feature_map.permute(0, 2, 3, 1).reshape(B, H * W,
+ C).contiguous()
+
+ out = out.type(feature_map.dtype)
+ out = out.reshape(B, N, C)
+ return out
+
+
+def token_interp(target_dict, source_dict):
+    """Transform token features between different token distributions.
+
+ Note:
+ B: batch size
+ N: token number
+ C: channel number
+
+ Args:
+ target_dict (dict): dict for target token information
+ source_dict (dict): dict for source token information.
+
+ Returns:
+ x_out (Tensor[B, N, C]): token features.
+ """
+
+ x_s = source_dict['x']
+ idx_token_s = source_dict['idx_token']
+ idx_token_t = target_dict['idx_token']
+ T = target_dict['token_num']
+ B, S, C = x_s.shape
+ N_init = idx_token_s.shape[1]
+
+ weight = target_dict['agg_weight'] if 'agg_weight' in target_dict.keys(
+ ) else None
+ if weight is None:
+ weight = x_s.new_ones(B, N_init, 1)
+ weight = weight.reshape(-1)
+
+ # choose the way with fewer flops.
+ if N_init < T * S:
+ # use sparse matrix multiplication
+ # Flops: B * N_init * (C+2)
+ idx_token_t = idx_token_t + torch.arange(
+ B, device=x_s.device)[:, None] * T
+ idx_token_s = idx_token_s + torch.arange(
+ B, device=x_s.device)[:, None] * S
+ coor = torch.stack([idx_token_t, idx_token_s],
+ dim=0).reshape(2, B * N_init)
+
+ # torch.sparse does not support grad for sparse matrix
+ weight = weight.float().detach().to(torch.float32)
+ # build a matrix with shape [B*T, B*S]
+ A = torch.sparse.FloatTensor(coor, weight, torch.Size([B * T, B * S]))
+ # normalize the matrix
+ all_weight = A.type(torch.float32) @ x_s.new_ones(B * S, 1).type(
+ torch.float32) + 1e-6
+ weight = weight / all_weight[(idx_token_t).reshape(-1), 0]
+ A = torch.sparse.FloatTensor(coor, weight, torch.Size([B * T, B * S]))
+ # sparse matmul
+ x_out = A.type(torch.float32) @ x_s.reshape(B * S, C).type(
+ torch.float32)
+ else:
+ # use dense matrix multiplication
+ # Flops: B * T * S * (C+2)
+ idx_batch = torch.arange(
+ B, device=x_s.device)[:, None].expand(B, N_init)
+ coor = torch.stack([idx_batch, idx_token_t, idx_token_s],
+ dim=0).reshape(3, B * N_init)
+ weight = weight.detach() # detach to reduce training time
+ # build a matrix with shape [B, T, S]
+ A = torch.sparse.FloatTensor(coor, weight, torch.Size([B, T,
+ S])).to_dense()
+ # normalize the matrix
+ A = A / (A.sum(dim=-1, keepdim=True) + 1e-6)
+ # dense matmul
+ x_out = A @ x_s
+
+ x_out = x_out.reshape(B, T, C).type(x_s.dtype)
+ return x_out
+
+
+def cluster_dpc_knn(token_dict, cluster_num, k=5, token_mask=None):
+ """Cluster tokens with DPC-KNN algorithm.
+
+ Note:
+ B: batch size
+ N: token number
+ C: channel number
+
+ Args:
+ token_dict (dict): dict for token information
+ cluster_num (int): cluster number
+ k (int): number of the nearest neighbor used for local density.
+ token_mask (Tensor[B, N]): mask indicating which token is the
+ padded empty token. Non-zero value means the token is meaningful,
+ zero value means the token is an empty token. If set to None, all
+ tokens are regarded as meaningful.
+
+ Return:
+ idx_cluster (Tensor[B, N]): cluster index of each token.
+        cluster_num (int): actual cluster number. In this function, it equals
+            the input cluster number.
+ """
+
+ with torch.no_grad():
+ x = token_dict['x']
+ B, N, C = x.shape
+
+ dist_matrix = torch.cdist(x, x) / (C**0.5)
+
+ if token_mask is not None:
+ token_mask = token_mask > 0
+ # in order to not affect the local density, the
+ # distance between empty tokens and any other
+ # tokens should be the maximal distance.
+ dist_matrix = \
+ dist_matrix * token_mask[:, None, :] +\
+ (dist_matrix.max() + 1) * (~token_mask[:, None, :])
+
+ # get local density
+ dist_nearest, index_nearest = torch.topk(
+ dist_matrix, k=k, dim=-1, largest=False)
+
+ density = (-(dist_nearest**2).mean(dim=-1)).exp()
+ # add a little noise to ensure no tokens have the same density.
+ density = density + torch.rand(
+ density.shape, device=density.device, dtype=density.dtype) * 1e-6
+
+ if token_mask is not None:
+ # the density of empty token should be 0
+ density = density * token_mask
+
+ # get distance indicator
+ mask = density[:, None, :] > density[:, :, None]
+ mask = mask.type(x.dtype)
+ dist_max = dist_matrix.flatten(1).max(dim=-1)[0][:, None, None]
+ dist, index_parent = (dist_matrix * mask + dist_max *
+ (1 - mask)).min(dim=-1)
+
+ # select clustering center according to score
+ score = dist * density
+ _, index_down = torch.topk(score, k=cluster_num, dim=-1)
+
+ # assign tokens to the nearest center
+ dist_matrix = index_points(dist_matrix, index_down)
+
+ idx_cluster = dist_matrix.argmin(dim=1)
+
+ # make sure cluster center merge to itself
+ idx_batch = torch.arange(
+ B, device=x.device)[:, None].expand(B, cluster_num)
+ idx_tmp = torch.arange(
+ cluster_num, device=x.device)[None, :].expand(B, cluster_num)
+ idx_cluster[idx_batch.reshape(-1),
+ index_down.reshape(-1)] = idx_tmp.reshape(-1)
+
+ return idx_cluster, cluster_num
+
+
+def merge_tokens(token_dict, idx_cluster, cluster_num, token_weight=None):
+    """Merge the tokens in each cluster into a single token. Implemented by
+    torch.index_add(). Flops: B*N*(C+2)
+
+ Note:
+ B: batch size
+ N: token number
+ C: channel number
+
+ Args:
+ token_dict (dict): dict for input token information
+ idx_cluster (Tensor[B, N]): cluster index of each token.
+ cluster_num (int): cluster number
+ token_weight (Tensor[B, N, 1]): weight for each token.
+
+ Return:
+ out_dict (dict): dict for output token information
+ """
+
+ x = token_dict['x']
+ idx_token = token_dict['idx_token']
+ agg_weight = token_dict['agg_weight']
+
+ B, N, C = x.shape
+ if token_weight is None:
+ token_weight = x.new_ones(B, N, 1)
+
+ idx_batch = torch.arange(B, device=x.device)[:, None]
+ idx = idx_cluster + idx_batch * cluster_num
+
+ all_weight = token_weight.new_zeros(B * cluster_num, 1)
+ all_weight.index_add_(
+ dim=0, index=idx.reshape(B * N), source=token_weight.reshape(B * N, 1))
+ all_weight = all_weight + 1e-6
+ norm_weight = token_weight / all_weight[idx]
+
+ # average token features
+ x_merged = x.new_zeros(B * cluster_num, C)
+ source = x * norm_weight
+ x_merged.index_add_(
+ dim=0,
+ index=idx.reshape(B * N),
+ source=source.reshape(B * N, C).type(x.dtype))
+ x_merged = x_merged.reshape(B, cluster_num, C)
+
+ idx_token_new = index_points(idx_cluster[..., None], idx_token).squeeze(-1)
+ weight_t = index_points(norm_weight, idx_token)
+ agg_weight_new = agg_weight * weight_t
+    # normalize the aggregation weight of each initial grid
+    agg_weight_new = agg_weight_new / agg_weight_new.max(
+        dim=1, keepdim=True)[0]
+
+ out_dict = {}
+ out_dict['x'] = x_merged
+ out_dict['token_num'] = cluster_num
+ out_dict['map_size'] = token_dict['map_size']
+ out_dict['init_grid_size'] = token_dict['init_grid_size']
+ out_dict['idx_token'] = idx_token_new
+ out_dict['agg_weight'] = agg_weight_new
+ return out_dict
+
+
+class MLP(nn.Module):
+ """FFN with Depthwise Conv of TCFormer.
+
+ Args:
+ in_features (int): The feature dimension.
+ hidden_features (int, optional): The hidden dimension of FFNs.
+ Defaults: The same as in_features.
+ out_features (int, optional): The output feature dimension.
+ Defaults: The same as in_features.
+ act_layer (nn.Module, optional): The activation config for FFNs.
+ Default: nn.GELU.
+ drop (float, optional): drop out rate. Default: 0.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def init_weights(self):
+ """init weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = self.fc1(x)
+ x = self.dwconv(x, H, W)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class DWConv(nn.Module):
+ """Depthwise Conv for regular grid-based tokens.
+
+ Args:
+ dim (int): The feature dimension.
+ """
+
+ def __init__(self, dim=768):
+ super(DWConv, self).__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ x = x.transpose(1, 2).view(B, C, H, W)
+ x = self.dwconv(x)
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class TCFormerRegularAttention(nn.Module):
+ """Spatial Reduction Attention for regular grid-based tokens.
+
+ Args:
+ dim (int): The feature dimension of tokens,
+ num_heads (int): Parallel attention heads.
+ qkv_bias (bool): enable bias for qkv if True. Default: False.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+ proj_drop (float): A Dropout layer after attention process.
+ Default: 0.0.
+ sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
+ Attention. Default: 1.
+ use_sr_conv (bool): If True, use a conv layer for spatial reduction.
+ If False, use a pooling process for spatial reduction. Defaults:
+ True.
+ """
+
+ def __init__(
+ self,
+ dim,
+ num_heads=8,
+ qkv_bias=False,
+ qk_scale=None,
+ attn_drop=0.,
+ proj_drop=0.,
+ sr_ratio=1,
+ use_sr_conv=True,
+ ):
+ super().__init__()
+ assert dim % num_heads == 0, \
+ f'dim {dim} should be divided by num_heads {num_heads}.'
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.sr_ratio = sr_ratio
+ self.use_sr_conv = use_sr_conv
+ if sr_ratio > 1 and self.use_sr_conv:
+ self.sr = nn.Conv2d(
+ dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.LayerNorm(dim)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ q = self.q(x).reshape(B, N, self.num_heads,
+ C // self.num_heads).permute(0, 2, 1, 3)
+
+ if self.sr_ratio > 1:
+ kv = x.permute(0, 2, 1).reshape(B, C, H, W)
+ if self.use_sr_conv:
+ kv = self.sr(kv).reshape(B, C, -1).permute(0, 2,
+ 1).contiguous()
+ kv = self.norm(kv)
+ else:
+ kv = F.avg_pool2d(
+ kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
+ kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()
+ else:
+ kv = x
+
+ kv = self.kv(kv).reshape(B, -1, 2, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1,
+ 4).contiguous()
+ k, v = kv[0], kv[1]
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
+
+class TCFormerRegularBlock(nn.Module):
+ """Transformer block for regular grid-based tokens.
+
+ Args:
+ dim (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ mlp_ratio (int): The expansion ratio for the FFNs.
+ qkv_bias (bool): enable bias for qkv if True. Default: False.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop (float): Dropout layers after attention process and in FFN.
+ Default: 0.0.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+        drop_path (float, optional): The drop path rate of transformer block.
+ Default: 0.0
+ act_layer (nn.Module, optional): The activation config for FFNs.
+ Default: nn.GELU.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
+ Attention. Default: 1.
+ use_sr_conv (bool): If True, use a conv layer for spatial reduction.
+ If False, use a pooling process for spatial reduction. Defaults:
+ True.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_cfg=dict(type='LN'),
+ sr_ratio=1,
+ use_sr_conv=True):
+ super().__init__()
+ self.norm1 = build_norm_layer(norm_cfg, dim)[1]
+
+ self.attn = TCFormerRegularAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ sr_ratio=sr_ratio,
+ use_sr_conv=use_sr_conv)
+ self.drop_path = build_dropout(
+ dict(type='DropPath', drop_prob=drop_path))
+
+ self.norm2 = build_norm_layer(norm_cfg, dim)[1]
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = MLP(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x, H, W):
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
+ return x
+
+
+class TokenConv(nn.Conv2d):
+ """Conv layer for dynamic tokens.
+
+    A skip link is added between the input and output tokens to preserve
+    detail information.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ groups = kwargs['groups'] if 'groups' in kwargs.keys() else 1
+ self.skip = nn.Conv1d(
+ in_channels=kwargs['in_channels'],
+ out_channels=kwargs['out_channels'],
+ kernel_size=1,
+ bias=False,
+ groups=groups)
+
+ def forward(self, token_dict):
+ x = token_dict['x']
+ x = self.skip(x.permute(0, 2, 1)).permute(0, 2, 1)
+ x_map = token2map(token_dict)
+ x_map = super().forward(x_map)
+ x = x + map2token(x_map, token_dict)
+ return x
+
+
+class TCMLP(nn.Module):
+ """FFN with Depthwise Conv for dynamic tokens.
+
+ Args:
+ in_features (int): The feature dimension.
+ hidden_features (int, optional): The hidden dimension of FFNs.
+ Defaults: The same as in_features.
+ out_features (int, optional): The output feature dimension.
+ Defaults: The same as in_features.
+ act_layer (nn.Module, optional): The activation config for FFNs.
+ Default: nn.GELU.
+ drop (float, optional): drop out rate. Default: 0.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.dwconv = TokenConv(
+ in_channels=hidden_features,
+ out_channels=hidden_features,
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ bias=True,
+ groups=hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def init_weights(self):
+ """init weights."""
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, token_dict):
+ token_dict['x'] = self.fc1(token_dict['x'])
+ x = self.dwconv(token_dict)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class TCFormerDynamicAttention(TCFormerRegularAttention):
+ """Spatial Reduction Attention for dynamic tokens."""
+
+ def forward(self, q_dict, kv_dict):
+ """Attention process for dynamic tokens.
+ Dynamic tokens are represented by a dict with the following keys:
+ x (torch.Tensor[B, N, C]): token features.
+ token_num(int): token number.
+ map_size(list[int] or tuple[int]): feature map resolution in
+ format [H, W].
+ init_grid_size(list[int] or tuple[int]): initial grid resolution
+ in format [H_init, W_init].
+ idx_token(torch.LongTensor[B, N_init]): indicates which token
+ the initial grid belongs to.
+            agg_weight (torch.Tensor[B, N_init, 1] or None): weight for
+ aggregation. Indicates the weight of each token in its
+ cluster. If set to None, uniform weight is used.
+
+ Note:
+ B: batch size
+ N: token number
+ C: channel number
+ Ns: sampled point number
+ [H_init, W_init]: shape of initial grid
+ [H, W]: shape of feature map
+ N_init: numbers of initial token
+
+ Args:
+ q_dict (dict): dict for query token information
+ kv_dict (dict): dict for key and value token information
+
+ Return:
+ x (torch.Tensor[B, N, C]): output token features.
+ """
+
+ q = q_dict['x']
+ kv = kv_dict['x']
+ B, Nq, C = q.shape
+ Nkv = kv.shape[1]
+ conf_kv = kv_dict['token_score'] if 'token_score' in kv_dict.keys(
+ ) else kv.new_zeros(B, Nkv, 1)
+
+ q = self.q(q).reshape(B, Nq, self.num_heads,
+ C // self.num_heads).permute(0, 2, 1,
+ 3).contiguous()
+
+ if self.sr_ratio > 1:
+ tmp = torch.cat([kv, conf_kv], dim=-1)
+ tmp_dict = kv_dict.copy()
+ tmp_dict['x'] = tmp
+ tmp_dict['map_size'] = q_dict['map_size']
+ tmp = token2map(tmp_dict)
+
+ kv = tmp[:, :C]
+ conf_kv = tmp[:, C:]
+
+ if self.use_sr_conv:
+ kv = self.sr(kv)
+ _, _, h, w = kv.shape
+ kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()
+ kv = self.norm(kv)
+ else:
+ kv = F.avg_pool2d(
+ kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
+ kv = kv.reshape(B, C, -1).permute(0, 2, 1).contiguous()
+
+ conf_kv = F.avg_pool2d(
+ conf_kv, kernel_size=self.sr_ratio, stride=self.sr_ratio)
+ conf_kv = conf_kv.reshape(B, 1, -1).permute(0, 2, 1).contiguous()
+
+ kv = self.kv(kv).reshape(B, -1, 2, self.num_heads,
+ C // self.num_heads).permute(2, 0, 3, 1,
+ 4).contiguous()
+ k, v = kv[0], kv[1]
+
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+
+ conf_kv = conf_kv.squeeze(-1)[:, None, None, :]
+ attn = attn + conf_kv
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+# Transformer block for dynamic tokens
+class TCFormerDynamicBlock(TCFormerRegularBlock):
+ """Transformer block for dynamic tokens.
+
+ Args:
+ dim (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ mlp_ratio (int): The expansion ratio for the FFNs.
+ qkv_bias (bool): enable bias for qkv if True. Default: False.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ drop (float): Dropout layers after attention process and in FFN.
+ Default: 0.0.
+ attn_drop (float): A Dropout layer on attn_output_weights.
+ Default: 0.0.
+        drop_path (float, optional): The drop path rate of transformer block.
+ Default: 0.0
+ act_layer (nn.Module, optional): The activation config for FFNs.
+ Default: nn.GELU.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ sr_ratio (int): The ratio of spatial reduction of Spatial Reduction
+ Attention. Default: 1.
+ use_sr_conv (bool): If True, use a conv layer for spatial reduction.
+ If False, use a pooling process for spatial reduction. Defaults:
+ True.
+ """
+
+ def __init__(self,
+ dim,
+ num_heads,
+ mlp_ratio=4.,
+ qkv_bias=False,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_cfg=dict(type='LN'),
+ sr_ratio=1,
+ use_sr_conv=True):
+ super(TCFormerRegularBlock, self).__init__()
+ self.norm1 = build_norm_layer(norm_cfg, dim)[1]
+
+ self.attn = TCFormerDynamicAttention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ sr_ratio=sr_ratio,
+ use_sr_conv=use_sr_conv)
+ self.drop_path = build_dropout(
+ dict(type='DropPath', drop_prob=drop_path))
+
+ self.norm2 = build_norm_layer(norm_cfg, dim)[1]
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = TCMLP(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop)
+
+ def forward(self, inputs):
+ """Forward function.
+
+ Args:
+            inputs (dict or tuple[dict] or list[dict]): input dynamic
+                token information. If a single dict is provided, it is
+                used as the query, key and value. If a tuple or list of
+                dicts is provided, the first one is used as the query
+                and the second one as the key and value.
+
+ Return:
+ q_dict (dict): dict for output token information
+ """
+        if isinstance(inputs, (tuple, list)):
+ q_dict, kv_dict = inputs
+ else:
+ q_dict, kv_dict = inputs, None
+
+ x = q_dict['x']
+ # norm1
+ q_dict['x'] = self.norm1(q_dict['x'])
+ if kv_dict is None:
+ kv_dict = q_dict
+ else:
+ kv_dict['x'] = self.norm1(kv_dict['x'])
+
+ # attn
+ x = x + self.drop_path(self.attn(q_dict, kv_dict))
+
+ # mlp
+ q_dict['x'] = self.norm2(x)
+ x = x + self.drop_path(self.mlp(q_dict))
+ q_dict['x'] = x
+
+ return q_dict
diff --git a/model-index.yml b/model-index.yml
index ab193e060b..aa7691be16 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -141,6 +141,7 @@ Import:
- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/hrnet_coco-wholebody.yml
- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/hrnet_dark_coco-wholebody.yml
- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/resnet_coco-wholebody.yml
+- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/tcformer_coco-wholebody.yml
- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/vipnas_coco-wholebody.yml
- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/coco-wholebody/vipnas_dark_coco-wholebody.yml
- configs/wholebody/2d_kpt_sview_rgb_img/topdown_heatmap/halpe/hrnet_dark_halpe.yml
diff --git a/tests/test_backbones/test_tcformer.py b/tests/test_backbones/test_tcformer.py
new file mode 100644
index 0000000000..848de97a82
--- /dev/null
+++ b/tests/test_backbones/test_tcformer.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmpose.models.backbones.tcformer import TCFormer
+
+
+def test_tcformer():
+ with pytest.raises(TypeError):
+ # Pretrained arg must be str or None.
+ TCFormer(pretrained=123)
+
+ # test load pretrained weights
+ model = TCFormer(
+ embed_dims=[64, 128, 320, 512],
+ num_heads=[1, 2, 5, 8],
+ mlp_ratios=[8, 8, 4, 4],
+ qkv_bias=True,
+ num_layers=[3, 4, 6, 3],
+ sr_ratios=[8, 4, 2, 1],
+ pretrained='https://download.openmmlab.com/mmpose/'
+ 'pretrain_models/tcformer-4e1adbf1_20220421.pth')
+ model.init_weights()
+
+ # test init weights from scratch
+ model = TCFormer(embed_dims=[32, 32, 32, 32], num_layers=[2, 2, 2, 2])
+ model.init_weights()
+
+ # Test normal inference
+ model = TCFormer()
+ temp = torch.randn((1, 3, 256, 192))
+ outs = model(temp)
+ assert len(outs) == 4
+ assert isinstance(outs[0], dict)
+ for key in [
+ 'x', 'token_num', 'map_size', 'init_grid_size', 'idx_token',
+ 'agg_weight'
+ ]:
+ assert key in outs[0].keys()
+
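+    # the token number shrinks by a factor of 4 at each stage:
+    # 64 * 48 = 3072 -> 768 -> 192 -> 48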
+ assert outs[0]['x'].shape == (1, 3072, 64)
+ assert outs[1]['x'].shape == (1, 768, 128)
+ assert outs[2]['x'].shape == (1, 192, 256)
+ assert outs[3]['x'].shape == (1, 48, 512)
+ assert outs[3]['idx_token'].shape == (1, 3072)
+ assert outs[3]['token_num'] == 48
+ assert outs[3]['map_size'] == [8, 6]
+ assert outs[3]['init_grid_size'] == [64, 48]
+
+    # Test inference with a non-standard input size
+ temp = torch.randn((1, 3, 193, 255))
+ outs = model(temp)
+ assert outs[0]['x'].shape == (1, 3136, 64)
+ assert outs[1]['x'].shape == (1, 784, 128)
+ assert outs[2]['x'].shape == (1, 196, 256)
+ assert outs[3]['x'].shape == (1, 49, 512)
+
+ # Test output feature map
+ model = TCFormer(return_map=True)
+ temp = torch.randn((1, 3, 256, 192))
+ outs = model(temp)
+ assert len(outs) == 4
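+    # with return_map=True, the token dicts are converted back to feature
+    # maps at strides 4, 8, 16 and 32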
+ assert outs[0].shape == (1, 64, 64, 48)
+ assert outs[1].shape == (1, 128, 32, 24)
+ assert outs[2].shape == (1, 256, 16, 12)
+ assert outs[3].shape == (1, 512, 8, 6)
diff --git a/tests/test_necks/test_tcformer_mta_neck.py b/tests/test_necks/test_tcformer_mta_neck.py
new file mode 100644
index 0000000000..3114077564
--- /dev/null
+++ b/tests/test_necks/test_tcformer_mta_neck.py
@@ -0,0 +1,150 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmpose.models.backbones.tcformer import TCFormer
+from mmpose.models.necks.tcformer_mta_neck import MTA
+
+
+def test_mta():
+ in_channels = [8, 16, 32, 64]
+ out_channels = 8
+
+ # end_level=-1 is equal to end_level=3
+ MTA(in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=0,
+ end_level=-1,
+ num_outs=5)
+ MTA(in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=0,
+ end_level=3,
+ num_outs=5)
+
+ # `num_outs` is not equal to end_level - start_level + 1
+ with pytest.raises(AssertionError):
+ MTA(in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=1,
+ end_level=2,
+ num_outs=3)
+
+    # `num_outs` is smaller than len(in_channels) - start_level
+ with pytest.raises(AssertionError):
+ MTA(in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=1,
+ num_outs=2)
+
+ # `end_level` is larger than len(in_channels) - 1
+ with pytest.raises(AssertionError):
+ MTA(in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=1,
+ end_level=4,
+ num_outs=2)
+
+ # `num_outs` is not equal to end_level - start_level
+ with pytest.raises(AssertionError):
+ MTA(in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=1,
+ end_level=3,
+ num_outs=1)
+
+ # Invalid `add_extra_convs` option
+ with pytest.raises(AssertionError):
+ MTA(in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=1,
+ add_extra_convs='on_xxx',
+ num_outs=5)
+
+ backbone = TCFormer(embed_dims=[8, 16, 32, 64])
+ temp = torch.randn((1, 3, 256, 192))
+ feats = backbone(temp)
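+    # stride-4 feature map resolution of the 256x192 input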
+ h, w = 64, 48
+
+ # normal forward
+ mta_model = MTA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_extra_convs=True,
+ num_outs=5)
+ assert mta_model.add_extra_convs == 'on_input'
+ outs = mta_model(feats)
+ assert len(outs) == 5
+    for i in range(mta_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] == h // (2**i)
+        assert outs[i].shape[3] == w // (2**i)
+
+    # Tests for MTA with no extra convs (pooling is used instead)
+ mta_model = MTA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=1,
+ add_extra_convs=False,
+ num_outs=5)
+ outs = mta_model(feats)
+ assert len(outs) == mta_model.num_outs
+ assert not mta_model.add_extra_convs
+    for i in range(mta_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        # start_level=1 shifts the pyramid, so only upper-bound the size
+        assert outs[i].shape[2] <= h // (2**i)
+        assert outs[i].shape[3] <= w // (2**i)
+
+    # Tests for MTA with lateral BNs
+ mta_model = MTA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ start_level=1,
+ add_extra_convs=True,
+ no_norm_on_lateral=False,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ num_outs=5)
+ outs = mta_model(feats)
+ assert len(outs) == mta_model.num_outs
+ assert mta_model.add_extra_convs == 'on_input'
+    for i in range(mta_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] <= h // (2**i)
+        assert outs[i].shape[3] <= w // (2**i)
+
+ bn_exist = False
+ for m in mta_model.modules():
+ if isinstance(m, _BatchNorm):
+ bn_exist = True
+ assert bn_exist
+
+    # Extra convs source is 'on_input'
+ mta_model = MTA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_extra_convs='on_input',
+ start_level=1,
+ num_outs=5)
+ assert mta_model.add_extra_convs == 'on_input'
+ outs = mta_model(feats)
+ assert len(outs) == mta_model.num_outs
+    for i in range(mta_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] <= h // (2**i)
+        assert outs[i].shape[3] <= w // (2**i)
+
+    # Extra convs source is 'on_output'
+ mta_model = MTA(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_extra_convs='on_output',
+ start_level=1,
+ num_outs=5)
+ assert mta_model.add_extra_convs == 'on_output'
+ outs = mta_model(feats)
+ assert len(outs) == mta_model.num_outs
+    for i in range(mta_model.num_outs):
+        assert outs[i].shape[1] == out_channels
+        assert outs[i].shape[2] <= h // (2**i)
+        assert outs[i].shape[3] <= w // (2**i)