diff --git a/LICENSES.md b/LICENSES.md
new file mode 100644
index 0000000000..d082607c09
--- /dev/null
+++ b/LICENSES.md
@@ -0,0 +1,7 @@
+# Licenses for special algorithms
+
+In this file, we list the algorithms that are released under licenses other than Apache 2.0. Users should be careful when using these algorithms for any commercial purposes.
+
+| Algorithm |                                                                            Files                                                                             |     License      |
+| :-------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------: | :--------------: |
+|  ED-Pose  | [mmpose/models/heads/transformer_heads/edpose_head.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/transformer_heads/edpose_head.py) | IDEA License 1.0 |
diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_coco.md b/configs/body_2d_keypoint/edpose/coco/edpose_coco.md
new file mode 100644
index 0000000000..cb1c566c00
--- /dev/null
+++ b/configs/body_2d_keypoint/edpose/coco/edpose_coco.md
@@ -0,0 +1,62 @@
+
+
+
+ED-Pose (ICLR'2023) + +```bibtex +@inproceedings{ +yang2023explicit, +title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation}, +author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang}, +booktitle={International Conference on Learning Representations}, +year={2023}, +url={https://openreview.net/forum?id=s4WVupnJjmX} +} +``` + +
+ + + +
+ResNet (CVPR'2016) + +```bibtex +@inproceedings{he2016deep, + title={Deep residual learning for image recognition}, + author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian}, + booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, + pages={770--778}, + year={2016} +} +``` + +
+ + + +
+COCO (ECCV'2014) + +```bibtex +@inproceedings{lin2014microsoft, + title={Microsoft coco: Common objects in context}, + author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence}, + booktitle={European conference on computer vision}, + pages={740--755}, + year={2014}, + organization={Springer} +} +``` + +
+
+Results on COCO val2017.
+
+| Arch | Backbone | AP | AP50 | AP75 | AR | AR50 | ckpt | log |
+| :-------------------------------------------- | :-------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :--------------------------------------------: | :-------------------------------------------: |
+| [edpose_res50_coco](/configs/body_2d_keypoint/edpose/coco/edpose_res50_8xb2-50e_coco-800x1333.py) | ResNet-50 | 0.716 | 0.898 | 0.783 | 0.793 | 0.944 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.json) |
+
+The checkpoint is converted from the official repo. Training ED-Pose is not supported yet; it will be supported in future updates.
+
+The above config follows [Pure Python style](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta). Please install `mmengine>=0.8.2` to use this config.
diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml b/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
new file mode 100644
index 0000000000..48b303b2c3
--- /dev/null
+++ b/configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
@@ -0,0 +1,26 @@
+Collections:
+- Name: ED-Pose
+  Paper:
+    Title: Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation
+    URL: https://arxiv.org/pdf/2302.01593.pdf
+  README: https://github.com/open-mmlab/mmpose/blob/main/docs/src/papers/algorithms/edpose.md
+Models:
+- Config: configs/body_2d_keypoint/edpose/coco/edpose_res50_8xb2-50e_coco-800x1333.py
+  In Collection: ED-Pose
+  Alias: edpose
+  Metadata:
+    Architecture: &id001
+    - ED-Pose
+    - ResNet
+    Training Data: COCO
+  Name: edpose_res50_8xb2-50e_coco-800x1333
+  Results:
+  - Dataset: COCO
+    Metrics:
+      AP: 0.716
+      AP@0.5: 0.898
+      AP@0.75: 0.783
+      AR: 0.793
+      AR@0.5: 0.944
+    Task: Body 2D Keypoint
+  Weights: https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth
diff --git a/configs/body_2d_keypoint/edpose/coco/edpose_res50_8xb2-50e_coco-800x1333.py b/configs/body_2d_keypoint/edpose/coco/edpose_res50_8xb2-50e_coco-800x1333.py
new file mode 100644
index 0000000000..56854d8807
--- /dev/null
+++ b/configs/body_2d_keypoint/edpose/coco/edpose_res50_8xb2-50e_coco-800x1333.py
@@ -0,0 +1,236 @@
+# Copyright (c) OpenMMLab. All rights reserved.
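+#
+# Usage sketch (annotation, not part of the original file): the converted
+# checkpoint listed in edpose_coco.md can be evaluated with this config via
+# the standard test script, e.g.
+#   python tools/test.py \
+#       configs/body_2d_keypoint/edpose/coco/edpose_res50_8xb2-50e_coco-800x1333.py \
+#       https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth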
+from mmengine.config import read_base + +with read_base(): + from mmpose.configs._base_.default_runtime import * # noqa + +from mmcv.transforms import RandomChoice, RandomChoiceResize +from mmengine.dataset import DefaultSampler +from mmengine.model import PretrainedInit +from mmengine.optim import LinearLR, MultiStepLR +from torch.nn import GroupNorm +from torch.optim import Adam + +from mmpose.codecs import EDPoseLabel +from mmpose.datasets import (BottomupRandomChoiceResize, BottomupRandomCrop, + CocoDataset, LoadImage, PackPoseInputs, + RandomFlip) +from mmpose.evaluation import CocoMetric +from mmpose.models import (BottomupPoseEstimator, ChannelMapper, EDPoseHead, + PoseDataPreprocessor, ResNet) +from mmpose.models.utils import FrozenBatchNorm2d + +# runtime +train_cfg.update(max_epochs=50, val_interval=10) # noqa + +# optimizer +optim_wrapper = dict(optimizer=dict( + type=Adam, + lr=1e-3, +)) + +# learning policy +param_scheduler = [ + dict(type=LinearLR, begin=0, end=500, start_factor=0.001, + by_epoch=False), # warm-up + dict( + type=MultiStepLR, + begin=0, + end=140, + milestones=[33, 45], + gamma=0.1, + by_epoch=True) +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=80) + +# hooks +default_hooks.update( # noqa + checkpoint=dict(save_best='coco/AP', rule='greater')) + +# codec settings +codec = dict(type=EDPoseLabel, num_select=50, num_keypoints=17) + +# model settings +model = dict( + type=BottomupPoseEstimator, + data_preprocessor=dict( + type=PoseDataPreprocessor, + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=1), + backbone=dict( + type=ResNet, + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type=FrozenBatchNorm2d, requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type=PretrainedInit, checkpoint='torchvision://resnet50')), + neck=dict( + type=ChannelMapper, + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type=GroupNorm, num_groups=32), + num_outs=4), + head=dict( + type=EDPoseHead, + num_queries=900, + num_feature_levels=4, + num_keypoints=17, + as_two_stage=True, + encoder=dict( + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=4, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + ffn_drop=0.0))), + decoder=dict( + num_layers=6, + embed_dims=256, + layer_cfg=dict( # DeformableDetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + batch_first=True), + cross_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=2048, ffn_drop=0.1)), + query_dim=4, + num_feature_levels=4, + num_group=100, + num_dn=100, + num_box_decoder_layers=2, + return_intermediate=True), + out_head=dict(num_classes=2), + positional_encoding=dict( + num_pos_feats=128, + temperatureH=20, + temperatureW=20, + normalize=True), + denosing_cfg=dict( + dn_box_noise_scale=0.4, + dn_label_noise_ratio=0.5, + dn_labelbook_size=100, + dn_attn_mask_type_list=['match2dn', 'dn2dn', 'group2group']), + data_decoder=codec), + test_cfg=dict(Pmultiscale_test=False, flip_test=False, num_select=50), + train_cfg=dict()) + +# enable DDP training when rescore net is used 
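+# (there is no rescore net in ED-Pose; the flag below is presumably needed
+# because some head parameters receive no gradients in certain iterations,
+# which DDP only tolerates when find_unused_parameters=True)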
+find_unused_parameters = True
+
+# base dataset settings
+dataset_type = CocoDataset
+data_mode = 'bottomup'
+data_root = 'data/coco/'
+
+# pipelines
+train_pipeline = [
+    dict(type=LoadImage),
+    dict(type=RandomFlip, direction='horizontal'),
+    dict(
+        type=RandomChoice,
+        transforms=[
+            [
+                dict(
+                    type=RandomChoiceResize,
+                    scales=[(480, 1333), (512, 1333), (544, 1333),
+                            (576, 1333), (608, 1333), (640, 1333),
+                            (672, 1333), (704, 1333), (736, 1333),
+                            (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ],
+            [
+                dict(
+                    type=BottomupRandomChoiceResize,
+                    # The aspect ratio of all images in the train dataset
+                    # is < 7, following the original implementation
+                    scales=[(400, 4200), (500, 4200), (600, 4200)],
+                    keep_ratio=True),
+                dict(
+                    type=BottomupRandomCrop,
+                    crop_type='absolute_range',
+                    crop_size=(384, 600),
+                    allow_negative_crop=True),
+                dict(
+                    type=BottomupRandomChoiceResize,
+                    scales=[(480, 1333), (512, 1333), (544, 1333),
+                            (576, 1333), (608, 1333), (640, 1333),
+                            (672, 1333), (704, 1333), (736, 1333),
+                            (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ]
+        ]),
+    dict(type=PackPoseInputs),
+]
+
+val_pipeline = [
+    dict(type=LoadImage),
+    dict(
+        type=BottomupRandomChoiceResize,
+        scales=[(800, 1333)],
+        keep_ratio=True,
+        backend='pillow'),
+    dict(
+        type=PackPoseInputs,
+        meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
+                   'img_shape', 'input_size', 'input_center', 'input_scale',
+                   'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
+                   'skeleton_links'))
+]
+
+# data loaders
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    persistent_workers=True,
+    sampler=dict(type=DefaultSampler, shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/person_keypoints_train2017.json',
+        data_prefix=dict(img='train2017/'),
+        pipeline=train_pipeline,
+    ))
+
+val_dataloader = dict(
+    batch_size=4,
+    num_workers=8,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type=DefaultSampler, shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/person_keypoints_val2017.json',
+        data_prefix=dict(img='val2017/'),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = dict(
+    type=CocoMetric,
+    nms_mode='none',
+    score_mode='keypoint',
+)
+test_evaluator = val_evaluator
diff --git a/docs/src/papers/algorithms/edpose.md b/docs/src/papers/algorithms/edpose.md
new file mode 100644
index 0000000000..07acf2edb5
--- /dev/null
+++ b/docs/src/papers/algorithms/edpose.md
@@ -0,0 +1,31 @@
+# Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation
+
+
+
+ED-Pose (ICLR'2023) + +```bibtex +@inproceedings{ +yang2023explicit, +title={Explicit Box Detection Unifies End-to-End Multi-Person Pose Estimation}, +author={Jie Yang and Ailing Zeng and Shilong Liu and Feng Li and Ruimao Zhang and Lei Zhang}, +booktitle={International Conference on Learning Representations}, +year={2023}, +url={https://openreview.net/forum?id=s4WVupnJjmX} +} +``` + +
+ +## Abstract + + + +This paper presents a novel end-to-end framework with Explicit box Detection for multi-person Pose estimation, called ED-Pose, where it unifies the contextual learning between human-level (global) and keypoint-level (local) information. Different from previous one-stage methods, ED-Pose re-considers this task as two explicit box detection processes with a unified representation and regression supervision. First, we introduce a human detection decoder from encoded tokens to extract global features. It can provide a good initialization for the latter keypoint detection, making the training process converge fast. Second, to bring in contextual information near keypoints, we regard pose estimation as a keypoint box detection problem to learn both box positions and contents for each keypoint. A human-to-keypoint detection decoder adopts an interactive learning strategy between human and keypoint features to further enhance global and local feature aggregation. In general, ED-Pose is conceptually simple without post-processing and dense heatmap supervision. It demonstrates its effectiveness and efficiency compared with both two-stage and one-stage methods. Notably, explicit box detection boosts the pose estimation performance by 4.5 AP on COCO and 9.9 AP on CrowdPose. For the first time, as a fully end-to-end framework with a L1 regression loss, ED-Pose surpasses heatmap-based Top-down methods under the same backbone by 1.2 AP on COCO and achieves the state-of-the-art with 76.6 AP on CrowdPose without bells and whistles. Code is available at https://github.com/IDEA-Research/ED-Pose. + + + +
+ +
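A usage sketch (annotation, not part of this diff): with the `edpose` alias
registered in `edpose_coco.yml` above, the converted checkpoint should be
usable through the high-level inferencer API, roughly as follows:

```python
from mmpose.apis import MMPoseInferencer

# 'edpose' is the alias declared in the metafile; config and weights
# are resolved and downloaded automatically
inferencer = MMPoseInferencer(pose2d='edpose')

# 'demo.jpg' is a placeholder image path; results are yielded lazily
result = next(inferencer('demo.jpg'))
```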
diff --git a/mmpose/apis/inference.py b/mmpose/apis/inference.py index 772ef17b7c..5662d6f30b 100644 --- a/mmpose/apis/inference.py +++ b/mmpose/apis/inference.py @@ -53,7 +53,8 @@ def dataset_meta_from_config(config: Config, import mmpose.datasets.datasets # noqa: F401, F403 from mmpose.registry import DATASETS - dataset_class = DATASETS.get(dataset_cfg.type) + dataset_class = dataset_cfg.type if isinstance( + dataset_cfg.type, type) else DATASETS.get(dataset_cfg.type) metainfo = dataset_class.METAINFO metainfo = parse_pose_metainfo(metainfo) diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py index 102a202e7d..f2dec61ca2 100644 --- a/mmpose/codecs/__init__.py +++ b/mmpose/codecs/__init__.py @@ -2,6 +2,7 @@ from .annotation_processors import YOLOXPoseAnnotationProcessor from .associative_embedding import AssociativeEmbedding from .decoupled_heatmap import DecoupledHeatmap +from .edpose_label import EDPoseLabel from .image_pose_lifting import ImagePoseLifting from .integral_regression_label import IntegralRegressionLabel from .megvii_heatmap import MegviiHeatmap @@ -17,5 +18,5 @@ 'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel', 'SimCCLabel', 'IntegralRegressionLabel', 'AssociativeEmbedding', 'SPR', 'DecoupledHeatmap', 'VideoPoseLifting', 'ImagePoseLifting', - 'MotionBERTLabel', 'YOLOXPoseAnnotationProcessor' + 'MotionBERTLabel', 'YOLOXPoseAnnotationProcessor', 'EDPoseLabel' ] diff --git a/mmpose/codecs/edpose_label.py b/mmpose/codecs/edpose_label.py new file mode 100644 index 0000000000..0433784886 --- /dev/null +++ b/mmpose/codecs/edpose_label.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import numpy as np + +from mmpose.registry import KEYPOINT_CODECS +from mmpose.structures import bbox_cs2xyxy, bbox_xyxy2cs +from .base import BaseKeypointCodec + + +@KEYPOINT_CODECS.register_module() +class EDPoseLabel(BaseKeypointCodec): + r"""Generate keypoint and label coordinates for `ED-Pose`_ by + Yang J. et al (2023). + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - image size: [w, h] + + Encoded: + + - keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D) + - keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K, D) + - area (np.ndarray): Area in shape (N) + - bbox (np.ndarray): Bbox in shape (N, 4) + + Args: + num_select (int): The number of candidate instances + num_keypoints (int): The Number of keypoints + """ + + auxiliary_encode_keys = {'area', 'bboxes', 'img_shape'} + instance_mapping_table = dict( + bbox='bboxes', + keypoints='keypoints', + keypoints_visible='keypoints_visible', + area='areas', + ) + + def __init__(self, num_select: int = 100, num_keypoints: int = 17): + super().__init__() + + self.num_select = num_select + self.num_keypoints = num_keypoints + + def encode( + self, + img_shape, + keypoints: np.ndarray, + keypoints_visible: Optional[np.ndarray] = None, + area: Optional[np.ndarray] = None, + bboxes: Optional[np.ndarray] = None, + ) -> dict: + """Encoding keypoints, area and bbox from input image space to + normalized space. + + Args: + - img_shape (Sequence[int]): The shape of image in the format + of (width, height). + - keypoints (np.ndarray): Keypoint coordinates in + shape (N, K, D). 
+ - keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K) + - area (np.ndarray): + - bboxes (np.ndarray): + + Returns: + encoded (dict): Contains the following items: + + - keypoint_labels (np.ndarray): The processed keypoints in + shape like (N, K, D). + - keypoints_visible (np.ndarray): Keypoint visibility in shape + (N, K, D) + - area_labels (np.ndarray): The processed target + area in shape (N). + - bboxes_labels: The processed target bbox in + shape (N, 4). + """ + w, h = img_shape + + if keypoints_visible is None: + keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32) + + if bboxes is not None: + bboxes = np.concatenate(bbox_xyxy2cs(bboxes), axis=-1) + bboxes = bboxes / np.array([w, h, w, h], dtype=np.float32) + + if area is not None: + area = area / float(w * h) + + if keypoints is not None: + keypoints = keypoints / np.array([w, h], dtype=np.float32) + + encoded = dict( + keypoints=keypoints, + area=area, + bbox=bboxes, + keypoints_visible=keypoints_visible) + + return encoded + + def decode(self, input_shapes: np.ndarray, pred_logits: np.ndarray, + pred_boxes: np.ndarray, pred_keypoints: np.ndarray): + """Select the final top-k keypoints, and decode the results from + normalize size to origin input size. + + Args: + input_shapes (Tensor): The size of input image resize. + test_cfg (ConfigType): Config of testing. + pred_logits (Tensor): The result of score. + pred_boxes (Tensor): The result of bbox. + pred_keypoints (Tensor): The result of keypoints. + + Returns: + tuple: Decoded boxes, keypoints, and keypoint scores. + """ + + # Initialization + num_keypoints = self.num_keypoints + prob = pred_logits.reshape(-1) + + # Select top-k instances based on prediction scores + topk_indexes = np.argsort(-prob)[:self.num_select] + topk_values = np.take_along_axis(prob, topk_indexes, axis=0) + scores = np.tile(topk_values[:, np.newaxis], [1, num_keypoints]) + + # Decode bounding boxes + topk_boxes = topk_indexes // pred_logits.shape[1] + boxes = bbox_cs2xyxy(*np.split(pred_boxes, [2], axis=-1)) + boxes = np.take_along_axis( + boxes, np.tile(topk_boxes[:, np.newaxis], [1, 4]), axis=0) + + # Convert from relative to absolute coordinates + img_h, img_w = np.split(input_shapes, 2, axis=0) + scale_fct = np.hstack([img_w, img_h, img_w, img_h]) + boxes = boxes * scale_fct[np.newaxis, :] + + # Decode keypoints + topk_keypoints = topk_indexes // pred_logits.shape[1] + keypoints = np.take_along_axis( + pred_keypoints, + np.tile(topk_keypoints[:, np.newaxis], [1, num_keypoints * 3]), + axis=0) + keypoints = keypoints[:, :(num_keypoints * 2)] + keypoints = keypoints * np.tile( + np.hstack([img_w, img_h]), [num_keypoints])[np.newaxis, :] + keypoints = keypoints.reshape(-1, num_keypoints, 2) + + return boxes, keypoints, scores diff --git a/mmpose/datasets/transforms/__init__.py b/mmpose/datasets/transforms/__init__.py index 46ca6c749e..fb9a5fc0bb 100644 --- a/mmpose/datasets/transforms/__init__.py +++ b/mmpose/datasets/transforms/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
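A decoding sketch (annotation, not part of this diff) for the `EDPoseLabel`
codec above; the shapes follow the docstrings, and the random values are
purely illustrative:

```python
import numpy as np

from mmpose.codecs import EDPoseLabel

codec = EDPoseLabel(num_select=50, num_keypoints=17)

num_queries = 900
pred_logits = np.random.rand(num_queries, 2)          # per-query class scores
pred_boxes = np.random.rand(num_queries, 4)           # normalized (cx, cy, w, h)
pred_keypoints = np.random.rand(num_queries, 17 * 3)  # normalized keypoint fields
input_shapes = np.array([800, 1333])                  # (h, w) of the resized input

boxes, keypoints, scores = codec.decode(input_shapes, pred_logits,
                                        pred_boxes, pred_keypoints)
assert boxes.shape == (50, 4)          # top-50 boxes in absolute pixels
assert keypoints.shape == (50, 17, 2)  # top-50 keypoint sets
assert scores.shape == (50, 17)        # per-keypoint scores
```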
from .bottomup_transforms import (BottomupGetHeatmapMask, BottomupRandomAffine, - BottomupResize) + BottomupRandomChoiceResize, + BottomupRandomCrop, BottomupResize) from .common_transforms import (Albumentation, FilterAnnotations, GenerateTarget, GetBBoxCenterScale, PhotometricDistortion, RandomBBoxTransform, @@ -18,5 +19,6 @@ 'PhotometricDistortion', 'PackPoseInputs', 'LoadImage', 'BottomupGetHeatmapMask', 'BottomupRandomAffine', 'BottomupResize', 'GenerateTarget', 'KeypointConverter', 'RandomFlipAroundRoot', - 'FilterAnnotations', 'YOLOXHSVRandomAug', 'YOLOXMixUp', 'Mosaic' + 'FilterAnnotations', 'YOLOXHSVRandomAug', 'YOLOXMixUp', 'Mosaic', + 'BottomupRandomCrop', 'BottomupRandomChoiceResize' ] diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py index 5ef2fa5838..0175e013dc 100644 --- a/mmpose/datasets/transforms/bottomup_transforms.py +++ b/mmpose/datasets/transforms/bottomup_transforms.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. from functools import partial -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import cv2 import numpy as np import xtcocotools.mask as cocomask from mmcv.image import imflip_, imresize +from mmcv.image.geometric import imrescale from mmcv.transforms import BaseTransform from mmcv.transforms.utils import cache_randomness from scipy.stats import truncnorm @@ -607,3 +608,416 @@ def transform(self, results: Dict) -> Optional[dict]: results['aug_scale'] = None return results + + +@TRANSFORMS.register_module() +class BottomupRandomCrop(BaseTransform): + """Random crop the image & bboxes & masks. + + The absolute ``crop_size`` is sampled based on ``crop_type`` and + ``image_size``, then the cropped results are generated. + + Required Keys: + + - img + - keypoints + - bbox (optional) + - masks (BitmapMasks | PolygonMasks) (optional) + + Modified Keys: + + - img + - img_shape + - keypoints + - keypoints_visible + - num_keypoints + - bbox (optional) + - bbox_score (optional) + - id (optional) + - category_id (optional) + - raw_ann_info (optional) + - iscrowd (optional) + - segmentation (optional) + - masks (optional) + + Added Keys: + + - warp_mat + + Args: + crop_size (tuple): The relative ratio or absolute pixels of + (width, height). + crop_type (str, optional): One of "relative_range", "relative", + "absolute", "absolute_range". "relative" randomly crops + (h * crop_size[0], w * crop_size[1]) part from an input of size + (h, w). "relative_range" uniformly samples relative crop size from + range [crop_size[0], 1] and [crop_size[1], 1] for height and width + respectively. "absolute" crops from an input with absolute size + (crop_size[0], crop_size[1]). "absolute_range" uniformly samples + crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w + in range [crop_size[0], min(w, crop_size[1])]. + Defaults to "absolute". + allow_negative_crop (bool, optional): Whether to allow a crop that does + not contain any bbox area. Defaults to False. + recompute_bbox (bool, optional): Whether to re-compute the boxes based + on cropped instance masks. Defaults to False. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. + + Note: + - If the image is smaller than the absolute crop size, return the + original image. + - If the crop does not contain any gt-bbox region and + ``allow_negative_crop`` is set to False, skip this image. 
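+
+    Example (illustrative, using the settings from the ED-Pose config)::
+
+        >>> import numpy as np
+        >>> transform = BottomupRandomCrop(
+        ...     crop_size=(384, 600),
+        ...     crop_type='absolute_range',
+        ...     allow_negative_crop=True)
+        >>> results = dict(img=np.zeros((800, 1333, 3), dtype=np.uint8))
+        >>> results = transform(results)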
+ """ + + def __init__(self, + crop_size: tuple, + crop_type: str = 'absolute', + allow_negative_crop: bool = False, + recompute_bbox: bool = False, + bbox_clip_border: bool = True) -> None: + if crop_type not in [ + 'relative_range', 'relative', 'absolute', 'absolute_range' + ]: + raise ValueError(f'Invalid crop_type {crop_type}.') + if crop_type in ['absolute', 'absolute_range']: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance( + crop_size[1], int) + if crop_type == 'absolute_range': + assert crop_size[0] <= crop_size[1] + else: + assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 + self.crop_size = crop_size + self.crop_type = crop_type + self.allow_negative_crop = allow_negative_crop + self.bbox_clip_border = bbox_clip_border + self.recompute_bbox = recompute_bbox + + def _crop_data(self, results: dict, crop_size: Tuple[int, int], + allow_negative_crop: bool) -> Union[dict, None]: + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + crop_size (Tuple[int, int]): Expected absolute size after + cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. + + Returns: + results (Union[dict, None]): Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. None will + be returned when there is no valid bbox after cropping. + """ + assert crop_size[0] > 0 and crop_size[1] > 0 + img = results['img'] + margin_h = max(img.shape[0] - crop_size[0], 0) + margin_w = max(img.shape[1] - crop_size[1], 0) + offset_h, offset_w = self._rand_offset((margin_h, margin_w)) + crop_y1, crop_y2 = offset_h, offset_h + crop_size[0] + crop_x1, crop_x2 = offset_w, offset_w + crop_size[1] + + # Record the warp matrix for the RandomCrop + warp_mat = np.array([[1, 0, -offset_w], [0, 1, -offset_h], [0, 0, 1]], + dtype=np.float32) + if results.get('warp_mat', None) is None: + results['warp_mat'] = warp_mat + else: + results['warp_mat'] = warp_mat @ results['warp_mat'] + + # crop the image + img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] + img_shape = img.shape + results['img'] = img + results['img_shape'] = img_shape[:2] + + # crop bboxes accordingly and clip to the image boundary + if results.get('bbox', None) is not None: + distances = (-offset_w, -offset_h) + bboxes = results['bbox'] + bboxes = bboxes + np.tile(np.asarray(distances), 2) + + if self.bbox_clip_border: + bboxes[..., 0::2] = bboxes[..., 0::2].clip(0, img_shape[1]) + bboxes[..., 1::2] = bboxes[..., 1::2].clip(0, img_shape[0]) + + valid_inds = (bboxes[..., 0] < img_shape[1]) & \ + (bboxes[..., 1] < img_shape[0]) & \ + (bboxes[..., 2] > 0) & \ + (bboxes[..., 3] > 0) + + # If the crop does not contain any gt-bbox area and + # allow_negative_crop is False, skip this image. 
+ if (not valid_inds.any() and not allow_negative_crop): + return None + + results['bbox'] = bboxes[valid_inds] + meta_keys = [ + 'bbox_score', 'id', 'category_id', 'raw_ann_info', 'iscrowd' + ] + for key in meta_keys: + if results.get(key): + if isinstance(results[key], list): + results[key] = np.asarray( + results[key])[valid_inds].tolist() + else: + results[key] = results[key][valid_inds] + + if results.get('keypoints', None) is not None: + keypoints = results['keypoints'] + distances = np.asarray(distances).reshape(1, 1, 2) + keypoints = keypoints + distances + if self.bbox_clip_border: + keypoints_outside_x = keypoints[:, :, 0] < 0 + keypoints_outside_y = keypoints[:, :, 1] < 0 + keypoints_outside_width = keypoints[:, :, 0] > img_shape[1] + keypoints_outside_height = keypoints[:, :, + 1] > img_shape[0] + + kpt_outside = np.logical_or.reduce( + (keypoints_outside_x, keypoints_outside_y, + keypoints_outside_width, keypoints_outside_height)) + + results['keypoints_visible'][kpt_outside] *= 0 + keypoints[:, :, 0] = keypoints[:, :, 0].clip(0, img_shape[1]) + keypoints[:, :, 1] = keypoints[:, :, 1].clip(0, img_shape[0]) + results['keypoints'] = keypoints[valid_inds] + results['keypoints_visible'] = results['keypoints_visible'][ + valid_inds] + + if results.get('segmentation', None) is not None: + results['segmentation'] = results['segmentation'][ + crop_y1:crop_y2, crop_x1:crop_x2] + + if results.get('masks', None) is not None: + results['masks'] = results['masks'][valid_inds.nonzero( + )[0]].crop(np.asarray([crop_x1, crop_y1, crop_x2, crop_y2])) + if self.recompute_bbox: + results['bbox'] = results['masks'].get_bboxes( + type(results['bbox'])) + + return results + + @cache_randomness + def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]: + """Randomly generate crop offset. + + Args: + margin (Tuple[int, int]): The upper bound for the offset generated + randomly. + + Returns: + Tuple[int, int]: The random offset for the crop. + """ + margin_h, margin_w = margin + offset_h = np.random.randint(0, margin_h + 1) + offset_w = np.random.randint(0, margin_w + 1) + + return offset_h, offset_w + + @cache_randomness + def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]: + """Randomly generates the absolute crop size based on `crop_type` and + `image_size`. + + Args: + image_size (Tuple[int, int]): (h, w). + + Returns: + crop_size (Tuple[int, int]): (crop_h, crop_w) in absolute pixels. + """ + h, w = image_size + if self.crop_type == 'absolute': + return min(self.crop_size[1], h), min(self.crop_size[0], w) + elif self.crop_type == 'absolute_range': + crop_h = np.random.randint( + min(h, self.crop_size[0]), + min(h, self.crop_size[1]) + 1) + crop_w = np.random.randint( + min(w, self.crop_size[0]), + min(w, self.crop_size[1]) + 1) + return crop_h, crop_w + elif self.crop_type == 'relative': + crop_w, crop_h = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + else: + # 'relative_range' + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + results (Union[dict, None]): Randomly cropped results, 'img_shape' + key in result dict is updated according to crop size. 
None will
+                be returned when there is no valid bbox after cropping.
+        """
+        image_size = results['img'].shape[:2]
+        crop_size = self._get_crop_size(image_size)
+        results = self._crop_data(results, crop_size, self.allow_negative_crop)
+        return results
+
+
+@TRANSFORMS.register_module()
+class BottomupRandomChoiceResize(BaseTransform):
+    """Resize images & bbox & mask from a list of multiple scales.
+
+    This transform resizes the input image to some scale. Bboxes and masks are
+    then resized with the same scale factor. Resize scale will be randomly
+    selected from ``scales``.
+
+    How to choose the target scale to resize the image will follow the rules
+    below:
+
+    - if ``scales`` is a list of tuples, the target scale is sampled from the
+      list uniformly.
+    - if ``scales`` is a tuple, the target scale will be set to the tuple.
+
+    Required Keys:
+
+    - img
+    - bbox
+    - keypoints
+
+    Modified Keys:
+
+    - img
+    - img_shape
+    - bbox
+    - keypoints
+
+    Added Keys:
+
+    - scale
+    - scale_factor
+    - scale_idx
+
+    Args:
+        scales (Union[list, Tuple]): Image scales for resizing.
+
+        **resize_kwargs: Other keyword arguments for the ``resize_type``.
+    """
+
+    def __init__(
+        self,
+        scales: Sequence[Union[int, Tuple]],
+        keep_ratio: bool = False,
+        clip_object_border: bool = True,
+        backend: str = 'cv2',
+        **resize_kwargs,
+    ) -> None:
+        super().__init__()
+        if isinstance(scales, list):
+            self.scales = scales
+        else:
+            self.scales = [scales]
+
+        self.keep_ratio = keep_ratio
+        self.clip_object_border = clip_object_border
+        self.backend = backend
+
+    @cache_randomness
+    def _random_select(self) -> Tuple[int, int]:
+        """Randomly select a scale from given candidates.
+
+        Returns:
+            (tuple, int): Returns a tuple ``(scale, scale_idx)``,
+            where ``scale`` is the selected image scale and
+            ``scale_idx`` is the selected index in the given candidates.
+ """ + + scale_idx = np.random.randint(len(self.scales)) + scale = self.scales[scale_idx] + return scale, scale_idx + + def _resize_img(self, results: dict) -> None: + """Resize images with ``self.scale``.""" + + if self.keep_ratio: + + img, scale_factor = imrescale( + results['img'], + self.scale, + interpolation='bilinear', + return_scale=True, + backend=self.backend) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = imresize( + results['img'], + self.scale, + interpolation='bilinear', + return_scale=True, + backend=self.backend) + + results['img'] = img + results['img_shape'] = img.shape[:2] + results['scale_factor'] = (w_scale, h_scale) + results['input_size'] = img.shape[:2] + w, h = results['ori_shape'] + center = np.array([w / 2, h / 2], dtype=np.float32) + scale = np.array([w, h], dtype=np.float32) + results['input_center'] = center + results['input_scale'] = scale + + def _resize_bboxes(self, results: dict) -> None: + """Resize bounding boxes with ``self.scale``.""" + if results.get('bbox', None) is not None: + bboxes = results['bbox'] * np.tile( + np.array(results['scale_factor']), 2) + if self.clip_object_border: + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, + results['img_shape'][1]) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, + results['img_shape'][0]) + results['bbox'] = bboxes + + def _resize_keypoints(self, results: dict) -> None: + """Resize keypoints with ``self.scale``.""" + if results.get('keypoints', None) is not None: + keypoints = results['keypoints'] + + keypoints[:, :, :2] = keypoints[:, :, :2] * np.array( + results['scale_factor']) + if self.clip_object_border: + keypoints[:, :, 0] = np.clip(keypoints[:, :, 0], 0, + results['img_shape'][1]) + keypoints[:, :, 1] = np.clip(keypoints[:, :, 1], 0, + results['img_shape'][0]) + results['keypoints'] = keypoints + + def transform(self, results: dict) -> dict: + """Apply resize transforms on results from a list of scales. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict: Resized results, 'img', 'bbox', + 'keypoints', 'scale', 'scale_factor', 'img_shape', + and 'keep_ratio' keys are updated in result dict. 
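+
+        Example (illustrative)::
+
+            >>> import numpy as np
+            >>> transform = BottomupRandomChoiceResize(
+            ...     scales=[(480, 1333), (800, 1333)], keep_ratio=True)
+            >>> # note: ``_resize_img`` unpacks ``ori_shape`` as (w, h)
+            >>> results = dict(
+            ...     img=np.zeros((480, 640, 3), dtype=np.uint8),
+            ...     ori_shape=(640, 480))
+            >>> results = transform(results)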
+ """ + + target_scale, scale_idx = self._random_select() + + self.scale = target_scale + self._resize_img(results) + self._resize_bboxes(results) + self._resize_keypoints(results) + + results['scale_idx'] = scale_idx + return results diff --git a/mmpose/models/data_preprocessors/__init__.py b/mmpose/models/data_preprocessors/__init__.py index 7abf9a6af0..89980f1f6e 100644 --- a/mmpose/models/data_preprocessors/__init__.py +++ b/mmpose/models/data_preprocessors/__init__.py @@ -2,4 +2,7 @@ from .batch_augmentation import BatchSyncRandomResize from .data_preprocessor import PoseDataPreprocessor -__all__ = ['PoseDataPreprocessor', 'BatchSyncRandomResize'] +__all__ = [ + 'PoseDataPreprocessor', + 'BatchSyncRandomResize', +] diff --git a/mmpose/models/data_preprocessors/data_preprocessor.py b/mmpose/models/data_preprocessors/data_preprocessor.py index b5ce1e7fdd..9442d0ed50 100644 --- a/mmpose/models/data_preprocessors/data_preprocessor.py +++ b/mmpose/models/data_preprocessors/data_preprocessor.py @@ -12,7 +12,43 @@ @MODELS.register_module() class PoseDataPreprocessor(ImgDataPreprocessor): - """Image pre-processor for pose estimation tasks.""" + """Image pre-processor for pose estimation tasks. + + Comparing with the :class:`ImgDataPreprocessor`, + + 1. It will additionally append batch_input_shape + to data_samples considering the DETR-based pose estimation tasks. + + 2. Support image augmentation transforms on batched data. + + It provides the data pre-processing as follows + + - Collate and move data to the target device. + - Pad inputs to the maximum size of current batch with defined + ``pad_value``. The padding size can be divisible by a defined + ``pad_size_divisor`` + - Stack inputs to batch_inputs. + - Convert inputs from bgr to rgb if the shape of input is (3, H, W). + - Normalize image with defined std and mean. + - Apply batch augmentation transforms. + + Args: + mean (sequence of float, optional): The pixel mean of R, G, B + channels. Defaults to None. + std (sequence of float, optional): The pixel standard deviation + of R, G, B channels. Defaults to None. + pad_size_divisor (int): The size of padded image should be + divisible by ``pad_size_divisor``. Defaults to 1. + pad_value (float or int): The padded pixel value. Defaults to 0. + bgr_to_rgb (bool): whether to convert image from BGR to RGB. + Defaults to False. + rgb_to_bgr (bool): whether to convert image from RGB to BGR. + Defaults to False. + non_blocking (bool): Whether block current process + when transferring data to device. Defaults to False. + batch_augments: (list of dict, optional): Configs of augmentation + transforms on batched data. Defaults to None. 
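+
+    Example (illustrative, with the normalization values used by the
+    ED-Pose config in this PR)::
+
+        >>> preprocessor = PoseDataPreprocessor(
+        ...     mean=[123.675, 116.28, 103.53],
+        ...     std=[58.395, 57.12, 57.375],
+        ...     bgr_to_rgb=True,
+        ...     pad_size_divisor=1)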
+ """ def __init__(self, mean: Sequence[float] = None, @@ -31,6 +67,7 @@ def __init__(self, bgr_to_rgb=bgr_to_rgb, rgb_to_bgr=rgb_to_bgr, non_blocking=non_blocking) + if batch_augments is not None: self.batch_augments = nn.ModuleList( [MODELS.build(aug) for aug in batch_augments]) @@ -51,6 +88,8 @@ def forward(self, data: dict, training: bool = False) -> dict: batch_pad_shape = self._get_pad_shape(data) data = super().forward(data=data, training=training) inputs, data_samples = data['inputs'], data['data_samples'] + + # update metainfo since the image shape might change batch_input_shape = tuple(inputs[0].size()[-2:]) for data_sample, pad_shape in zip(data_samples, batch_pad_shape): data_sample.set_metainfo({ @@ -58,6 +97,7 @@ def forward(self, data: dict, training: bool = False) -> dict: 'pad_shape': pad_shape }) + # apply batch augmentations if training and self.batch_augments is not None: for batch_aug in self.batch_augments: inputs, data_samples = batch_aug(inputs, data_samples) diff --git a/mmpose/models/heads/__init__.py b/mmpose/models/heads/__init__.py index ef0e17d98e..8415354486 100644 --- a/mmpose/models/heads/__init__.py +++ b/mmpose/models/heads/__init__.py @@ -8,11 +8,12 @@ MotionRegressionHead, RegressionHead, RLEHead, TemporalRegressionHead, TrajectoryRegressionHead) +from .transformer_heads import EDPoseHead __all__ = [ 'BaseHead', 'HeatmapHead', 'CPMHead', 'MSPNHead', 'ViPNASHead', 'RegressionHead', 'IntegralRegressionHead', 'SimCCHead', 'RLEHead', 'DSNTHead', 'AssociativeEmbeddingHead', 'DEKRHead', 'VisPredictHead', 'CIDHead', 'RTMCCHead', 'TemporalRegressionHead', - 'TrajectoryRegressionHead', 'MotionRegressionHead' + 'TrajectoryRegressionHead', 'MotionRegressionHead', 'EDPoseHead' ] diff --git a/mmpose/models/heads/transformer_heads/__init__.py b/mmpose/models/heads/transformer_heads/__init__.py new file mode 100644 index 0000000000..bb16484ff8 --- /dev/null +++ b/mmpose/models/heads/transformer_heads/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .edpose_head import EDPoseHead +from .transformers import (FFN, DeformableDetrTransformerDecoder, + DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, + DeformableDetrTransformerEncoderLayer, + DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer, + PositionEmbeddingSineHW) + +__all__ = [ + 'EDPoseHead', 'DetrTransformerEncoder', 'DetrTransformerDecoder', + 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer', + 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder', + 'DeformableDetrTransformerEncoderLayer', + 'DeformableDetrTransformerDecoderLayer', 'PositionEmbeddingSineHW', 'FFN' +] diff --git a/mmpose/models/heads/transformer_heads/base_transformer_head.py b/mmpose/models/heads/transformer_heads/base_transformer_head.py new file mode 100644 index 0000000000..96855e186d --- /dev/null +++ b/mmpose/models/heads/transformer_heads/base_transformer_head.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import abstractmethod +from typing import Dict, Tuple + +import torch +from torch import Tensor + +from mmpose.registry import MODELS +from mmpose.utils.typing import (Features, OptConfigType, OptMultiConfig, + OptSampleList, Predictions) +from ..base_head import BaseHead + + +@MODELS.register_module() +class TransformerHead(BaseHead): + r"""Implementation of `Deformable DETR: Deformable Transformers for + End-to-End Object Detection `_ + + Code is modified from the `official github repo + `_. + + Args: + encoder (ConfigDict, optional): Config of the + Transformer encoder. Defaults to None. + decoder (ConfigDict, optional): Config of the + Transformer decoder. Defaults to None. + out_head (ConfigDict, optional): Config for the + bounding final out head module. Defaults to None. + positional_encoding (ConfigDict, optional): Config for + transformer position encoding. Defaults to None. + num_queries (int): Number of query in Transformer. + loss (ConfigDict, optional): Config for loss functions. + Defaults to None. + init_cfg (ConfigDict, optional): Config to control the initialization. + """ + + def __init__(self, + encoder: OptConfigType = None, + decoder: OptConfigType = None, + out_head: OptConfigType = None, + positional_encoding: OptConfigType = None, + num_queries: int = 100, + loss: OptConfigType = None, + init_cfg: OptMultiConfig = None): + + if init_cfg is None: + init_cfg = self.default_init_cfg + + super().__init__(init_cfg) + + self.encoder_cfg = encoder + self.decoder_cfg = decoder + self.out_head_cfg = out_head + self.positional_encoding_cfg = positional_encoding + self.num_queries = num_queries + + def forward(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Dict: + """Forward the network.""" + encoder_outputs_dict = self.forward_encoder(feats, batch_data_samples) + + decoder_outputs_dict = self.forward_decoder(**encoder_outputs_dict) + + head_outputs_dict = self.forward_out_head(batch_data_samples, + **decoder_outputs_dict) + return head_outputs_dict + + @abstractmethod + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: OptConfigType = {}) -> Predictions: + """Predict results from features.""" + pass + + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + pass + + @abstractmethod + def forward_encoder(self, feat: Tensor, feat_mask: Tensor, + feat_pos: Tensor, **kwargs) -> Dict: + pass + + @abstractmethod + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + **kwargs) -> Dict: + pass + + @abstractmethod + def forward_out_head(self, query: Tensor, query_pos: Tensor, + memory: Tensor, **kwargs) -> Dict: + pass + + @staticmethod + def get_valid_ratio(mask: Tensor) -> Tensor: + """Get the valid radios of feature map in a level. + + .. code:: text + + |---> valid_W <---| + ---+-----------------+-----+--- + A | | | A + | | | | | + | | | | | + valid_H | | | | + | | | | H + | | | | | + V | | | | + ---+-----------------+ | | + | | V + +-----------------------+--- + |---------> W <---------| + + The valid_ratios are defined as: + r_h = valid_H / H, r_w = valid_W / W + They are the factors to re-normalize the relative coordinates of the + image to the relative coordinates of the current level feature map. + + Args: + mask (Tensor): Binary mask of a feature map, has shape (bs, H, W). 
+ + Returns: + Tensor: valid ratios [r_w, r_h] of a feature map, has shape (1, 2). + """ + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio diff --git a/mmpose/models/heads/transformer_heads/edpose_head.py b/mmpose/models/heads/transformer_heads/edpose_head.py new file mode 100644 index 0000000000..d864f8fadd --- /dev/null +++ b/mmpose/models/heads/transformer_heads/edpose_head.py @@ -0,0 +1,1346 @@ +# ---------------------------------------------------------------------------- +# Adapted from https://github.com/IDEA-Research/ED-Pose/ \ +# tree/master/models/edpose +# Original licence: IDEA License 1.0 +# ---------------------------------------------------------------------------- + +import copy +import math +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import BaseModule, ModuleList, constant_init +from mmengine.structures import InstanceData +from torch import Tensor, nn + +from mmpose.models.utils import inverse_sigmoid +from mmpose.registry import KEYPOINT_CODECS, MODELS +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, Features, OptConfigType, + OptSampleList, Predictions) +from .base_transformer_head import TransformerHead +from .transformers.deformable_detr_layers import ( + DeformableDetrTransformerDecoderLayer, DeformableDetrTransformerEncoder) +from .transformers.utils import FFN, PositionEmbeddingSineHW + + +class EDPoseDecoder(BaseModule): + """Transformer decoder of EDPose: `Explicit Box Detection Unifies End-to- + End Multi-Person Pose Estimation. + + Args: + layer_cfg (ConfigDict): the config of each encoder + layer. All the layers will share the same config. + num_layers (int): Number of decoder layers. + return_intermediate (bool, optional): Whether to return outputs of + intermediate layers. Defaults to `True`. + embed_dims (int): Dims of embed. + query_dim (int): Dims of queries. + num_feature_levels (int): Number of feature levels. + num_box_decoder_layers (int): Number of box decoder layers. + num_keypoints (int): Number of datasets' body keypoints. + num_dn (int): Number of denosing points. + num_group (int): Number of decoder layers. 
+ """ + + def __init__(self, + layer_cfg, + num_layers, + return_intermediate, + embed_dims: int = 256, + query_dim=4, + num_feature_levels=1, + num_box_decoder_layers=2, + num_keypoints=17, + num_dn=100, + num_group=100): + super().__init__() + + self.layer_cfg = layer_cfg + self.num_layers = num_layers + self.embed_dims = embed_dims + + assert return_intermediate, 'support return_intermediate only' + self.return_intermediate = return_intermediate + + assert query_dim in [ + 2, 4 + ], 'query_dim should be 2/4 but {}'.format(query_dim) + self.query_dim = query_dim + + self.num_feature_levels = num_feature_levels + + self.layers = ModuleList([ + DeformableDetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.norm = nn.LayerNorm(self.embed_dims) + + self.ref_point_head = FFN(self.query_dim // 2 * self.embed_dims, + self.embed_dims, self.embed_dims, 2) + + self.num_keypoints = num_keypoints + self.query_scale = None + self.bbox_embed = None + self.class_embed = None + self.pose_embed = None + self.pose_hw_embed = None + self.num_box_decoder_layers = num_box_decoder_layers + self.box_pred_damping = None + self.num_group = num_group + self.rm_detach = None + self.num_dn = num_dn + self.hw = nn.Embedding(self.num_keypoints, 2) + self.keypoint_embed = nn.Embedding(self.num_keypoints, embed_dims) + self.kpt_index = [ + x for x in range(self.num_group * (self.num_keypoints + 1)) + if x % (self.num_keypoints + 1) != 0 + ] + + def forward(self, query: Tensor, value: Tensor, key_padding_mask: Tensor, + reference_points: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + humandet_attn_mask: Tensor, human2pose_attn_mask: Tensor, + **kwargs) -> Tuple[Tensor]: + """Forward function of decoder + Args: + query (Tensor): The input queries, has shape (bs, num_queries, + dim). + value (Tensor): The input values, has shape (bs, num_value, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged + as (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`, optional): Used for refining + the regression results. + + Returns: + Tuple[Tuple[Tensor]]: Outputs of Deformable Transformer Decoder. + + - output (Tuple[Tensor]): Output embeddings of the last decoder, + each has shape (num_decoder_layers, num_queries, bs, embed_dims) + - reference_points (Tensor): The reference of the last decoder + layer, each has shape (num_decoder_layers, bs, num_queries, 4). 
+ The coordinates are arranged as (cx, cy, w, h) + """ + output = query + attn_mask = humandet_attn_mask + intermediate = [] + intermediate_reference_points = [reference_points] + effect_num_dn = self.num_dn if self.training else 0 + inter_select_number = self.num_group + for layer_id, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[None, :] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios[None, :] + + query_sine_embed = self.get_proposal_pos_embed( + reference_points_input[:, :, 0, :]) # nq, bs, 256*2 + query_pos = self.ref_point_head(query_sine_embed) # nq, bs, 256 + + output = layer( + output.transpose(0, 1), + query_pos=query_pos.transpose(0, 1), + value=value.transpose(0, 1), + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input.transpose( + 0, 1).contiguous(), + self_attn_mask=attn_mask, + **kwargs) + output = output.transpose(0, 1) + intermediate.append(self.norm(output)) + + # human update + if layer_id < self.num_box_decoder_layers: + delta_unsig = self.bbox_embed[layer_id](output) + new_reference_points = delta_unsig + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + + # query expansion + if layer_id == self.num_box_decoder_layers - 1: + dn_output = output[:effect_num_dn] + dn_new_reference_points = new_reference_points[:effect_num_dn] + class_unselected = self.class_embed[layer_id]( + output)[effect_num_dn:] + topk_proposals = torch.topk( + class_unselected.max(-1)[0], inter_select_number, dim=0)[1] + new_reference_points_for_box = torch.gather( + new_reference_points[effect_num_dn:], 0, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + new_output_for_box = torch.gather( + output[effect_num_dn:], 0, + topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) + bs = new_output_for_box.shape[1] + new_output_for_keypoint = new_output_for_box[:, None, :, :] \ + + self.keypoint_embed.weight[None, :, None, :] + if self.num_keypoints == 17: + delta_xy = self.pose_embed[-1](new_output_for_keypoint)[ + ..., :2] + else: + delta_xy = self.pose_embed[0](new_output_for_keypoint)[ + ..., :2] + keypoint_xy = (inverse_sigmoid( + new_reference_points_for_box[..., :2][:, None]) + + delta_xy).sigmoid() + num_queries, _, bs, _ = keypoint_xy.shape + keypoint_wh_weight = self.hw.weight.unsqueeze(0).unsqueeze( + -2).repeat(num_queries, 1, bs, 1).sigmoid() + keypoint_wh = keypoint_wh_weight * \ + new_reference_points_for_box[..., 2:][:, None] + new_reference_points_for_keypoint = torch.cat( + (keypoint_xy, keypoint_wh), dim=-1) + new_reference_points = torch.cat( + (new_reference_points_for_box.unsqueeze(1), + new_reference_points_for_keypoint), + dim=1).flatten(0, 1) + output = torch.cat( + (new_output_for_box.unsqueeze(1), new_output_for_keypoint), + dim=1).flatten(0, 1) + new_reference_points = torch.cat( + (dn_new_reference_points, new_reference_points), dim=0) + output = torch.cat((dn_output, output), dim=0) + attn_mask = human2pose_attn_mask + + # human-to-keypoints update + if layer_id >= self.num_box_decoder_layers: + effect_num_dn = self.num_dn if self.training else 0 + inter_select_number = self.num_group + ref_before_sigmoid = inverse_sigmoid(reference_points) + output_bbox_dn = output[:effect_num_dn] + 
output_bbox_norm = output[effect_num_dn:][0::( + self.num_keypoints + 1)] + ref_before_sigmoid_bbox_dn = \ + ref_before_sigmoid[:effect_num_dn] + ref_before_sigmoid_bbox_norm = \ + ref_before_sigmoid[effect_num_dn:][0::( + self.num_keypoints + 1)] + delta_unsig_dn = self.bbox_embed[layer_id](output_bbox_dn) + delta_unsig_norm = self.bbox_embed[layer_id](output_bbox_norm) + outputs_unsig_dn = delta_unsig_dn + ref_before_sigmoid_bbox_dn + outputs_unsig_norm = delta_unsig_norm + \ + ref_before_sigmoid_bbox_norm + new_reference_points_for_box_dn = outputs_unsig_dn.sigmoid() + new_reference_points_for_box_norm = outputs_unsig_norm.sigmoid( + ) + output_kpt = output[effect_num_dn:].index_select( + 0, torch.tensor(self.kpt_index, device=output.device)) + delta_xy_unsig = self.pose_embed[layer_id - + self.num_box_decoder_layers]( + output_kpt) + outputs_unsig = ref_before_sigmoid[ + effect_num_dn:].index_select( + 0, torch.tensor(self.kpt_index, + device=output.device)).clone() + delta_hw_unsig = self.pose_hw_embed[ + layer_id - self.num_box_decoder_layers]( + output_kpt) + outputs_unsig[..., :2] += delta_xy_unsig[..., :2] + outputs_unsig[..., 2:] += delta_hw_unsig + new_reference_points_for_keypoint = outputs_unsig.sigmoid() + bs = new_reference_points_for_box_norm.shape[1] + new_reference_points_norm = torch.cat( + (new_reference_points_for_box_norm.unsqueeze(1), + new_reference_points_for_keypoint.view( + -1, self.num_keypoints, bs, 4)), + dim=1).flatten(0, 1) + new_reference_points = torch.cat( + (new_reference_points_for_box_dn, + new_reference_points_norm), + dim=0) + + reference_points = new_reference_points.detach() + intermediate_reference_points.append(reference_points) + + decoder_outputs = [itm_out.transpose(0, 1) for itm_out in intermediate] + reference_points = [ + itm_refpoint.transpose(0, 1) + for itm_refpoint in intermediate_reference_points + ] + + return decoder_outputs, reference_points + + @staticmethod + def get_proposal_pos_embed(pos_tensor: Tensor, + temperature: int = 10000, + num_pos_feats: int = 128) -> Tensor: + """Get the position embedding of the proposal. + + Args: + pos_tensor (Tensor): Not normalized proposals, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h). + temperature (int, optional): The temperature used for scaling the + position embedding. Defaults to 10000. + num_pos_feats (int, optional): The feature dimension for each + position along x, y, w, and h-axis. Note the final returned + dimension for each position is 4 times of num_pos_feats. + Default to 128. 
+ + Returns: + Tensor: The position embedding of proposal, has shape + (bs, num_queries, num_pos_feats * 4), with the last dimension + arranged as (cx, cy, w, h) + """ + + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=pos_tensor.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), + dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), + dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack( + (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), + dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack( + (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), + dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError('Unknown pos_tensor shape(-1):{}'.format( + pos_tensor.size(-1))) + return pos + + +class EDPoseOutHead(BaseModule): + """Final Head of EDPose: `Explicit Box Detection Unifies End-to-End Multi- + Person Pose Estimation. + + Args: + num_classes (int): The number of classes. + num_keypoints (int): The number of datasets' body keypoints. + num_queries (int): The number of queries. + cls_no_bias (bool): Weather add the bias to class embed. + embed_dims (int): The dims of embed. + as_two_stage (bool, optional): Whether to generate the proposal + from the outputs of encoder. Defaults to `False`. + refine_queries_num (int): The number of refines queries after + decoders. + num_box_decoder_layers (int): The number of bbox decoder layer. + num_group (int): The number of groups. + num_pred_layer (int): The number of the prediction layers. + Defaults to 6. + dec_pred_class_embed_share (bool): Whether to share parameters + for all the class prediction layers. Defaults to `False`. + dec_pred_bbox_embed_share (bool): Whether to share parameters + for all the bbox prediction layers. Defaults to `False`. + dec_pred_pose_embed_share (bool): Whether to share parameters + for all the pose prediction layers. Defaults to `False`. 
+ """ + + def __init__(self, + num_classes, + num_keypoints: int = 17, + num_queries: int = 900, + cls_no_bias: bool = False, + embed_dims: int = 256, + as_two_stage: bool = False, + refine_queries_num: int = 100, + num_box_decoder_layers: int = 2, + num_group: int = 100, + num_pred_layer: int = 6, + dec_pred_class_embed_share: bool = False, + dec_pred_bbox_embed_share: bool = False, + dec_pred_pose_embed_share: bool = False, + **kwargs): + super().__init__() + self.embed_dims = embed_dims + self.as_two_stage = as_two_stage + self.num_classes = num_classes + self.refine_queries_num = refine_queries_num + self.num_box_decoder_layers = num_box_decoder_layers + self.num_keypoints = num_keypoints + self.num_queries = num_queries + + # prepare pred layers + self.dec_pred_class_embed_share = dec_pred_class_embed_share + self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share + self.dec_pred_pose_embed_share = dec_pred_pose_embed_share + # prepare class & box embed + _class_embed = nn.Linear( + self.embed_dims, self.num_classes, bias=(not cls_no_bias)) + if not cls_no_bias: + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + _class_embed.bias.data = torch.ones(self.num_classes) * bias_value + + _bbox_embed = FFN(self.embed_dims, self.embed_dims, 4, 3) + _pose_embed = FFN(self.embed_dims, self.embed_dims, 2, 3) + _pose_hw_embed = FFN(self.embed_dims, self.embed_dims, 2, 3) + + self.num_group = num_group + if dec_pred_bbox_embed_share: + box_embed_layerlist = [_bbox_embed for i in range(num_pred_layer)] + else: + box_embed_layerlist = [ + copy.deepcopy(_bbox_embed) for i in range(num_pred_layer) + ] + if dec_pred_class_embed_share: + class_embed_layerlist = [ + _class_embed for i in range(num_pred_layer) + ] + else: + class_embed_layerlist = [ + copy.deepcopy(_class_embed) for i in range(num_pred_layer) + ] + + if num_keypoints == 17: + if dec_pred_pose_embed_share: + pose_embed_layerlist = [ + _pose_embed + for i in range(num_pred_layer - num_box_decoder_layers + 1) + ] + else: + pose_embed_layerlist = [ + copy.deepcopy(_pose_embed) + for i in range(num_pred_layer - num_box_decoder_layers + 1) + ] + else: + if dec_pred_pose_embed_share: + pose_embed_layerlist = [ + _pose_embed + for i in range(num_pred_layer - num_box_decoder_layers) + ] + else: + pose_embed_layerlist = [ + copy.deepcopy(_pose_embed) + for i in range(num_pred_layer - num_box_decoder_layers) + ] + + pose_hw_embed_layerlist = [ + _pose_hw_embed + for i in range(num_pred_layer - num_box_decoder_layers) + ] + self.bbox_embed = nn.ModuleList(box_embed_layerlist) + self.class_embed = nn.ModuleList(class_embed_layerlist) + self.pose_embed = nn.ModuleList(pose_embed_layerlist) + self.pose_hw_embed = nn.ModuleList(pose_hw_embed_layerlist) + + def init_weights(self) -> None: + """Initialize weights of the Deformable DETR head.""" + + for m in self.bbox_embed: + constant_init(m[-1], 0, bias=0) + for m in self.pose_embed: + constant_init(m[-1], 0, bias=0) + + def forward(self, hidden_states: List[Tensor], references: List[Tensor], + mask_dict: Dict, hidden_states_enc: Tensor, + referens_enc: Tensor, batch_data_samples) -> Dict: + """Forward function. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries, dim). + references (list[Tensor]): List of the reference from the decoder. + + Returns: + tuple[Tensor]: results of head containing the following tensor. 
+ + - pred_logits (Tensor): Outputs from the + classification head, the classification scores of each bbox. + - pred_boxes (Tensor): The output boxes. + - pred_keypoints (Tensor): The output keypoints. + """ + # update human boxes + effec_dn_num = self.refine_queries_num if self.training else 0 + outputs_coord_list = [] + outputs_class = [] + for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_cls_embed, + layer_hs) in enumerate( + zip(references[:-1], self.bbox_embed, + self.class_embed, hidden_states)): + if dec_lid < self.num_box_decoder_layers: + layer_delta_unsig = layer_bbox_embed(layer_hs) + layer_outputs_unsig = layer_delta_unsig + inverse_sigmoid( + layer_ref_sig) + layer_outputs_unsig = layer_outputs_unsig.sigmoid() + layer_cls = layer_cls_embed(layer_hs) + outputs_coord_list.append(layer_outputs_unsig) + outputs_class.append(layer_cls) + else: + layer_hs_bbox_dn = layer_hs[:, :effec_dn_num, :] + layer_hs_bbox_norm = \ + layer_hs[:, effec_dn_num:, :][:, 0::( + self.num_keypoints + 1), :] + bs = layer_ref_sig.shape[0] + ref_before_sigmoid_bbox_dn = \ + layer_ref_sig[:, : effec_dn_num, :] + ref_before_sigmoid_bbox_norm = \ + layer_ref_sig[:, effec_dn_num:, :][:, 0::( + self.num_keypoints + 1), :] + layer_delta_unsig_dn = layer_bbox_embed(layer_hs_bbox_dn) + layer_delta_unsig_norm = layer_bbox_embed(layer_hs_bbox_norm) + layer_outputs_unsig_dn = layer_delta_unsig_dn + \ + inverse_sigmoid(ref_before_sigmoid_bbox_dn) + layer_outputs_unsig_dn = layer_outputs_unsig_dn.sigmoid() + layer_outputs_unsig_norm = layer_delta_unsig_norm + \ + inverse_sigmoid(ref_before_sigmoid_bbox_norm) + layer_outputs_unsig_norm = layer_outputs_unsig_norm.sigmoid() + layer_outputs_unsig = torch.cat( + (layer_outputs_unsig_dn, layer_outputs_unsig_norm), dim=1) + layer_cls_dn = layer_cls_embed(layer_hs_bbox_dn) + layer_cls_norm = layer_cls_embed(layer_hs_bbox_norm) + layer_cls = torch.cat((layer_cls_dn, layer_cls_norm), dim=1) + outputs_class.append(layer_cls) + outputs_coord_list.append(layer_outputs_unsig) + + # update keypoint boxes + outputs_keypoints_list = [] + kpt_index = [ + x for x in range(self.num_group * (self.num_keypoints + 1)) + if x % (self.num_keypoints + 1) != 0 + ] + for dec_lid, (layer_ref_sig, layer_hs) in enumerate( + zip(references[:-1], hidden_states)): + if dec_lid < self.num_box_decoder_layers: + assert isinstance(layer_hs, torch.Tensor) + bs = layer_hs.shape[0] + layer_res = layer_hs.new_zeros( + (bs, self.num_queries, self.num_keypoints * 3)) + outputs_keypoints_list.append(layer_res) + else: + bs = layer_ref_sig.shape[0] + layer_hs_kpt = \ + layer_hs[:, effec_dn_num:, :].index_select( + 1, torch.tensor(kpt_index, device=layer_hs.device)) + delta_xy_unsig = self.pose_embed[dec_lid - + self.num_box_decoder_layers]( + layer_hs_kpt) + layer_ref_sig_kpt = \ + layer_ref_sig[:, effec_dn_num:, :].index_select( + 1, torch.tensor(kpt_index, device=layer_hs.device)) + layer_outputs_unsig_keypoints = delta_xy_unsig + \ + inverse_sigmoid(layer_ref_sig_kpt[..., :2]) + vis_xy_unsig = torch.ones_like( + layer_outputs_unsig_keypoints, + device=layer_outputs_unsig_keypoints.device) + xyv = torch.cat((layer_outputs_unsig_keypoints, + vis_xy_unsig[:, :, 0].unsqueeze(-1)), + dim=-1) + xyv = xyv.sigmoid() + layer_res = xyv.reshape( + (bs, self.num_group, self.num_keypoints, 3)).flatten(2, 3) + layer_res = self.keypoint_xyzxyz_to_xyxyzz(layer_res) + outputs_keypoints_list.append(layer_res) + + dn_mask_dict = mask_dict + if self.refine_queries_num > 0 and dn_mask_dict is not None: + outputs_class,
outputs_coord_list, outputs_keypoints_list = \ + self.dn_post_process2( + outputs_class, outputs_coord_list, + outputs_keypoints_list, dn_mask_dict + ) + + for _out_class, _out_bbox, _out_keypoint in zip( + outputs_class, outputs_coord_list, outputs_keypoints_list): + assert _out_class.shape[1] == \ + _out_bbox.shape[1] == _out_keypoint.shape[1] + + return outputs_class[-1], outputs_coord_list[ + -1], outputs_keypoints_list[-1] + + def keypoint_xyzxyz_to_xyxyzz(self, keypoints: torch.Tensor): + """Rearrange keypoints from the interleaved (x1, y1, v1, ..., xk, yk, + vk) layout to the grouped (x1, y1, ..., xk, yk, v1, ..., vk) layout. + + Args: + keypoints (torch.Tensor): Keypoints in interleaved layout, has + shape (..., num_keypoints * 3), e.g. (..., 51) for COCO. + """ + res = torch.zeros_like(keypoints) + num_points = keypoints.shape[-1] // 3 + res[..., 0:2 * num_points:2] = keypoints[..., 0::3] + res[..., 1:2 * num_points:2] = keypoints[..., 1::3] + res[..., 2 * num_points:] = keypoints[..., 2::3] + return res + + +@MODELS.register_module() +class EDPoseHead(TransformerHead): + """Head introduced in `Explicit Box Detection Unifies End-to-End Multi- + Person Pose Estimation`_ by J. Yang et al. (2023). The head is composed of + an encoder, a decoder and an out head. + + Code is modified from the `official github repo + `_. + + More details can be found in the `paper + `_. + + Args: + num_queries (int): Number of queries in the Transformer. + num_feature_levels (int): Number of feature levels. Defaults to 4. + num_keypoints (int): Number of keypoints. Defaults to 17. + as_two_stage (bool, optional): Whether to generate the proposal + from the outputs of encoder. Defaults to `False`. + encoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer encoder. Defaults to None. + decoder (:obj:`ConfigDict` or dict, optional): Config of the + Transformer decoder. Defaults to None. + out_head (:obj:`ConfigDict` or dict, optional): Config for the + final out head module. Defaults to None. + positional_encoding (:obj:`ConfigDict` or dict): Config for + transformer position encoding. Defaults to None. + denosing_cfg (:obj:`ConfigDict` or dict, optional): Config of the + human query denoising training strategy. + data_decoder (:obj:`ConfigDict` or dict, optional): Config of the + data decoder which transforms the results from output space to + input space. + dec_pred_class_embed_share (bool): Whether to share the class embed + layer. Defaults to False. + dec_pred_bbox_embed_share (bool): Whether to share the bbox embed + layer. Defaults to False. + refine_queries_num (int): Number of refined human content queries + and their position queries. + two_stage_keep_all_tokens (bool): Whether to keep all tokens.
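To make the layout conversion above concrete, here is a self-contained re-implementation of `keypoint_xyzxyz_to_xyxyzz` run on a tiny hand-made tensor (two keypoints instead of COCO's 17):

```python
import torch


def xyzxyz_to_xyxyzz(keypoints: torch.Tensor) -> torch.Tensor:
    # (x1, y1, v1, x2, y2, v2, ...) -> (x1, y1, x2, y2, ..., v1, v2, ...)
    res = torch.zeros_like(keypoints)
    num_points = keypoints.shape[-1] // 3
    res[..., 0:2 * num_points:2] = keypoints[..., 0::3]
    res[..., 1:2 * num_points:2] = keypoints[..., 1::3]
    res[..., 2 * num_points:] = keypoints[..., 2::3]
    return res


# Two keypoints: (x1, y1, v1, x2, y2, v2)
kpts = torch.tensor([[1., 2., 0.9, 3., 4., 0.8]])
out = xyzxyz_to_xyxyzz(kpts)
assert out.tolist() == [[1., 2., 3., 4., 0.9, 0.8]]
```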
+ """ + + def __init__(self, + num_queries: int = 100, + num_feature_levels: int = 4, + num_keypoints: int = 17, + as_two_stage: bool = False, + encoder: OptConfigType = None, + decoder: OptConfigType = None, + out_head: OptConfigType = None, + positional_encoding: OptConfigType = None, + data_decoder: OptConfigType = None, + denosing_cfg: OptConfigType = None, + dec_pred_class_embed_share: bool = False, + dec_pred_bbox_embed_share: bool = False, + refine_queries_num: int = 100, + two_stage_keep_all_tokens: bool = False) -> None: + + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.refine_queries_num = refine_queries_num + self.dec_pred_class_embed_share = dec_pred_class_embed_share + self.dec_pred_bbox_embed_share = dec_pred_bbox_embed_share + self.two_stage_keep_all_tokens = two_stage_keep_all_tokens + self.num_heads = decoder['layer_cfg']['self_attn_cfg']['num_heads'] + self.num_group = decoder['num_group'] + self.num_keypoints = num_keypoints + self.denosing_cfg = denosing_cfg + if data_decoder is not None: + self.data_decoder = KEYPOINT_CODECS.build(data_decoder) + else: + self.data_decoder = None + + super().__init__( + encoder=encoder, + decoder=decoder, + out_head=out_head, + positional_encoding=positional_encoding, + num_queries=num_queries) + + self.positional_encoding = PositionEmbeddingSineHW( + **self.positional_encoding_cfg) + self.encoder = DeformableDetrTransformerEncoder(**self.encoder_cfg) + self.decoder = EDPoseDecoder( + num_keypoints=num_keypoints, **self.decoder_cfg) + self.out_head = EDPoseOutHead( + num_keypoints=num_keypoints, + as_two_stage=as_two_stage, + refine_queries_num=refine_queries_num, + **self.out_head_cfg, + **self.decoder_cfg) + + self.embed_dims = self.encoder.embed_dims + self.label_enc = nn.Embedding( + self.denosing_cfg['dn_labelbook_size'] + 1, self.embed_dims) + + if not self.as_two_stage: + self.query_embedding = nn.Embedding(self.num_queries, + self.embed_dims) + self.refpoint_embedding = nn.Embedding(self.num_queries, 4) + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + + self.decoder.bbox_embed = self.out_head.bbox_embed + self.decoder.pose_embed = self.out_head.pose_embed + self.decoder.pose_hw_embed = self.out_head.pose_hw_embed + self.decoder.class_embed = self.out_head.class_embed + + if self.as_two_stage: + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + if dec_pred_class_embed_share and dec_pred_bbox_embed_share: + self.enc_out_bbox_embed = self.out_head.bbox_embed[0] + else: + self.enc_out_bbox_embed = copy.deepcopy( + self.out_head.bbox_embed[0]) + + if dec_pred_class_embed_share and dec_pred_bbox_embed_share: + self.enc_out_class_embed = self.out_head.class_embed[0] + else: + self.enc_out_class_embed = copy.deepcopy( + self.out_head.class_embed[0]) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super().init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if self.as_two_stage: + nn.init.xavier_uniform_(self.memory_trans_fc.weight) + + nn.init.normal_(self.level_embed) + + def pre_transformer(self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None + ) -> Tuple[Dict]: + """Process image features before feeding them to the 
transformer. + + Args: + img_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + tuple[dict]: The first dict contains the inputs of encoder and the + second dict contains the inputs of decoder. + + - encoder_inputs_dict (dict): The keyword args dictionary of + `self.encoder()`. + - decoder_inputs_dict (dict): The keyword args dictionary of + `self.forward_decoder()`, which includes 'memory_mask'. + """ + batch_size = img_feats[0].size(0) + # construct binary masks for the transformer. + assert batch_data_samples is not None + batch_input_shape = batch_data_samples[0].batch_input_shape + img_shape_list = [sample.img_shape for sample in batch_data_samples] + input_img_h, input_img_w = batch_input_shape + masks = img_feats[0].new_ones((batch_size, input_img_h, input_img_w)) + for img_id in range(batch_size): + img_h, img_w = img_shape_list[img_id] + masks[img_id, :img_h, :img_w] = 0 + # NOTE following the official DETR repo, non-zero values representing + # ignored positions, while zero values means valid positions. + + mlvl_masks = [] + mlvl_pos_embeds = [] + for feat in img_feats: + mlvl_masks.append( + F.interpolate(masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1])) + + feat_flatten = [] + lvl_pos_embed_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(img_feats, mlvl_masks, mlvl_pos_embeds)): + batch_size, c, h, w = feat.shape + # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c] + feat = feat.view(batch_size, c, -1).permute(0, 2, 1) + pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl] + mask = mask.flatten(1) + spatial_shape = (h, w) + + feat_flatten.append(feat) + lvl_pos_embed_flatten.append(lvl_pos_embed) + mask_flatten.append(mask) + spatial_shapes.append(spatial_shape) + + # (bs, num_feat_points, dim) + feat_flatten = torch.cat(feat_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl) + mask_flatten = torch.cat(mask_flatten, 1) + + spatial_shapes = torch.as_tensor( # (num_level, 2) + spatial_shapes, + dtype=torch.long, + device=feat_flatten.device) + level_start_index = torch.cat(( + spatial_shapes.new_zeros((1, )), # (num_level) + spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( # (bs, num_level, 2) + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + if self.refine_queries_num > 0 or batch_data_samples is not None: + input_query_label, input_query_bbox, humandet_attn_mask, \ + human2pose_attn_mask, mask_dict =\ + self.prepare_for_denosing( + batch_data_samples, + device=img_feats[0].device) + else: + assert batch_data_samples is None + input_query_bbox = input_query_label = \ + humandet_attn_mask = human2pose_attn_mask = mask_dict = None + + encoder_inputs_dict = dict( + query=feat_flatten, + query_pos=lvl_pos_embed_flatten, + key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + 
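The bookkeeping in `pre_transformer` is easiest to see on toy numbers: each level contributes `h*w` tokens to one flattened sequence, and `level_start_index` is the running offset where each level begins. A minimal sketch (the spatial shapes are illustrative assumptions):

```python
import torch

# Toy spatial shapes for three feature levels: (h, w)
spatial_shapes = torch.tensor([[32, 32], [16, 16], [8, 8]], dtype=torch.long)

# Levels are concatenated along the token dimension, so each level's start
# offset is the cumulative sum of the preceding levels' h*w token counts.
num_tokens = spatial_shapes.prod(1)  # tensor([1024, 256, 64])
level_start_index = torch.cat(
    (num_tokens.new_zeros((1, )), num_tokens.cumsum(0)[:-1]))

print(level_start_index.tolist())  # [0, 1024, 1280]
assert num_tokens.sum().item() == 1024 + 256 + 64
```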
decoder_inputs_dict = dict( + memory_mask=mask_flatten, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + humandet_attn_mask=humandet_attn_mask, + human2pose_attn_mask=human2pose_attn_mask, + input_query_bbox=input_query_bbox, + input_query_label=input_query_label, + mask_dict=mask_dict) + return encoder_inputs_dict, decoder_inputs_dict + + def forward_encoder(self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None) -> Dict: + """Forward with Transformer encoder. + + The forward procedure is defined as: + 'pre_transformer' -> 'encoder' + + Args: + img_feats (tuple[Tensor]): Multi-level features that may have + different resolutions, output from neck. Each feature has + shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'. + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + dict: The dictionary of encoder outputs, which includes the + `memory` of the encoder output. + """ + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + + memory = self.encoder(**encoder_inputs_dict) + encoder_outputs_dict = dict(memory=memory, **decoder_inputs_dict) + return encoder_outputs_dict + + def pre_decoder(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor, input_query_bbox: Tensor, + input_query_label: Tensor) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query` and `reference_points`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). It will only be used when + `as_two_stage` is `True`. + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + It will only be used when `as_two_stage` is `True`. + input_query_bbox (Tensor): Denosing bbox query for training. + input_query_label (Tensor): Denosing label query for training. + + Returns: + tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.decoder()`. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions. 
+ """ + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + enc_outputs_class = self.enc_out_class_embed(output_memory) + enc_outputs_coord_unact = self.enc_out_bbox_embed( + output_memory) + output_proposals + + topk_proposals = torch.topk( + enc_outputs_class.max(-1)[0], self.num_queries, dim=1)[1] + topk_coords_undetach = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_undetach.detach() + reference_points = topk_coords_unact.sigmoid() + + query_undetach = torch.gather( + output_memory, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, self.embed_dims)) + query = query_undetach.detach() + + if input_query_bbox is not None: + reference_points = torch.cat( + [input_query_bbox, topk_coords_unact], dim=1).sigmoid() + query = torch.cat([input_query_label, query], dim=1) + if self.two_stage_keep_all_tokens: + hidden_states_enc = output_memory.unsqueeze(0) + referens_enc = enc_outputs_coord_unact.unsqueeze(0) + else: + hidden_states_enc = query_undetach.unsqueeze(0) + referens_enc = topk_coords_undetach.sigmoid().unsqueeze(0) + else: + hidden_states_enc, referens_enc = None, None + query = self.query_embedding.weight[:, None, :].repeat( + 1, bs, 1).transpose(0, 1) + reference_points = \ + self.refpoint_embedding.weight[:, None, :].repeat(1, bs, 1) + + if input_query_bbox is not None: + reference_points = torch.cat( + [input_query_bbox, reference_points], dim=1) + query = torch.cat([input_query_label, query], dim=1) + reference_points = reference_points.sigmoid() + + decoder_inputs_dict = dict( + query=query, reference_points=reference_points) + head_inputs_dict = dict( + hidden_states_enc=hidden_states_enc, referens_enc=referens_enc) + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor, level_start_index: Tensor, + valid_ratios: Tensor, humandet_attn_mask: Tensor, + human2pose_attn_mask: Tensor, input_query_bbox: Tensor, + input_query_label: Tensor, mask_dict: Dict) -> Dict: + """Forward with Transformer decoder. + + The forward procedure is defined as: + 'pre_decoder' -> 'decoder' + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + humandet_attn_mask (Tensor): Human attention mask. + human2pose_attn_mask (Tensor): Human to pose attention mask. + input_query_bbox (Tensor): Denosing bbox query for training. + input_query_label (Tensor): Denosing label query for training. + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output and `references` including + the initial and intermediate reference_points. 
+ """ + decoder_in, head_in = self.pre_decoder(memory, memory_mask, + spatial_shapes, + input_query_bbox, + input_query_label) + + inter_states, inter_references = self.decoder( + query=decoder_in['query'].transpose(0, 1), + value=memory.transpose(0, 1), + key_padding_mask=memory_mask, # for cross_attn + reference_points=decoder_in['reference_points'].transpose(0, 1), + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + humandet_attn_mask=humandet_attn_mask, + human2pose_attn_mask=human2pose_attn_mask) + references = inter_references + decoder_outputs_dict = dict( + hidden_states=inter_states, + references=references, + mask_dict=mask_dict) + decoder_outputs_dict.update(head_in) + return decoder_outputs_dict + + def forward_out_head(self, batch_data_samples: OptSampleList, + hidden_states: List[Tensor], references: List[Tensor], + mask_dict: Dict, hidden_states_enc: Tensor, + referens_enc: Tensor) -> Tuple[Tensor]: + """Forward function.""" + out = self.out_head(hidden_states, references, mask_dict, + hidden_states_enc, referens_enc, + batch_data_samples) + return out + + def predict(self, + feats: Features, + batch_data_samples: OptSampleList, + test_cfg: ConfigType = {}) -> Predictions: + """Predict results from features.""" + input_shapes = np.array( + [d.metainfo['input_size'] for d in batch_data_samples]) + + if test_cfg.get('flip_test', False): + assert NotImplementedError( + 'flip_test is currently not supported ' + 'for EDPose. Please set `model.test_cfg.flip_test=False`') + else: + pred_logits, pred_boxes, pred_keypoints = self.forward( + feats, batch_data_samples) # (B, K, D) + + pred = self.decode( + input_shapes, + pred_logits=pred_logits, + pred_boxes=pred_boxes, + pred_keypoints=pred_keypoints) + return pred + + def decode(self, input_shapes: np.ndarray, pred_logits: Tensor, + pred_boxes: Tensor, pred_keypoints: Tensor): + """Select the final top-k keypoints, and decode the results from + normalize size to origin input size. + + Args: + input_shapes (Tensor): The size of input image. + pred_logits (Tensor): The result of score. + pred_boxes (Tensor): The result of bbox. + pred_keypoints (Tensor): The result of keypoints. + + Returns: + """ + + if self.data_decoder is None: + raise RuntimeError(f'The data decoder has not been set in \ + {self.__class__.__name__}. ' + 'Please set the data decoder configs in \ + the init parameters to ' + 'enable head methods `head.predict()` and \ + `head.decode()`') + + preds = [] + + pred_logits = pred_logits.sigmoid() + pred_logits, pred_boxes, pred_keypoints = to_numpy( + [pred_logits, pred_boxes, pred_keypoints]) + + for input_shape, pred_logit, pred_bbox, pred_kpts in zip( + input_shapes, pred_logits, pred_boxes, pred_keypoints): + + bboxes, keypoints, keypoint_scores = self.data_decoder.decode( + input_shape, pred_logit, pred_bbox, pred_kpts) + + # pack outputs + preds.append( + InstanceData( + keypoints=keypoints, + keypoint_scores=keypoint_scores, + bboxes=bboxes)) + + return preds + + def gen_encoder_output_proposals(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor + ) -> Tuple[Tensor, Tensor]: + """Generate proposals from encoded memory. The function will only be + used when `as_two_stage` is `True`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). 
+ spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + + Returns: + tuple: A tuple of transformed memory and proposals. + + - output_memory (Tensor): The transformed memory for obtaining + top-k proposals, has shape (bs, num_feat_points, dim). + - output_proposals (Tensor): The inverse-normalized proposal, has + shape (batch_size, num_keys, 4) with the last dimension arranged + as (cx, cy, w, h). + """ + bs = memory.size(0) + proposals = [] + _cur = 0 # start index in the sequence of the current level + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_mask[:, + _cur:(_cur + H * W)].view(bs, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace( + 0, W - 1, W, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) + proposal = torch.cat((grid, wh), -1).view(bs, -1, 4) + proposals.append(proposal) + _cur += (H * W) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & + (output_proposals < 0.99)).all( + -1, keepdim=True) + + output_proposals = inverse_sigmoid(output_proposals) + output_proposals = output_proposals.masked_fill( + memory_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, + float(0)) + output_memory = self.memory_trans_fc(output_memory) + output_memory = self.memory_trans_norm(output_memory) + # [bs, sum(hw), 2] + return output_memory, output_proposals + + @property + def default_init_cfg(self): + init_cfg = [dict(type='Normal', layer=['Linear'], std=0.01, bias=0)] + return init_cfg + + def prepare_for_denosing(self, targets: OptSampleList, device): + """prepare for dn components in forward function.""" + if not self.training: + bs = len(targets) + attn_mask_infere = torch.zeros( + bs, + self.num_heads, + self.num_group * (self.num_keypoints + 1), + self.num_group * (self.num_keypoints + 1), + device=device, + dtype=torch.bool) + group_bbox_kpt = (self.num_keypoints + 1) + kpt_index = [ + x for x in range(self.num_group * (self.num_keypoints + 1)) + if x % (self.num_keypoints + 1) == 0 + ] + for matchj in range(self.num_group * (self.num_keypoints + 1)): + sj = (matchj // group_bbox_kpt) * group_bbox_kpt + ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt + if sj > 0: + attn_mask_infere[:, :, matchj, :sj] = True + if ej < self.num_group * (self.num_keypoints + 1): + attn_mask_infere[:, :, matchj, ej:] = True + for match_x in range(self.num_group * (self.num_keypoints + 1)): + if match_x % group_bbox_kpt == 0: + attn_mask_infere[:, :, match_x, kpt_index] = False + + attn_mask_infere = attn_mask_infere.flatten(0, 1) + return None, None, None, attn_mask_infere, None + + # targets, dn_scalar, noise_scale = dn_args + device = targets[0]['boxes'].device + bs = len(targets) + refine_queries_num = self.refine_queries_num + + # gather gt 
boxes and labels + gt_boxes = [t['boxes'] for t in targets] + gt_labels = [t['labels'] for t in targets] + gt_keypoints = [t['keypoints'] for t in targets] + + # repeat them + def get_indices_for_repeat(now_num, target_num, device='cuda'): + """ + Input: + - now_num: int + - target_num: int + Output: + - indices: tensor[target_num] + """ + out_indice = [] + base_indice = torch.arange(now_num).to(device) + multiplier = target_num // now_num + out_indice.append(base_indice.repeat(multiplier)) + residue = target_num % now_num + out_indice.append(base_indice[torch.randint( + 0, now_num, (residue, ), device=device)]) + return torch.cat(out_indice) + + gt_boxes_expand = [] + gt_labels_expand = [] + gt_keypoints_expand = [] + for idx, (gt_boxes_i, gt_labels_i, gt_keypoint_i) in enumerate( + zip(gt_boxes, gt_labels, gt_keypoints)): + num_gt_i = gt_boxes_i.shape[0] + if num_gt_i > 0: + indices = get_indices_for_repeat(num_gt_i, refine_queries_num, + device) + gt_boxes_expand_i = gt_boxes_i[indices] # num_dn, 4 + gt_labels_expand_i = gt_labels_i[indices] + gt_keypoints_expand_i = gt_keypoint_i[indices] + else: + # all negative samples when no gt boxes + gt_boxes_expand_i = torch.rand( + refine_queries_num, 4, device=device) + gt_labels_expand_i = torch.ones( + refine_queries_num, dtype=torch.int64, + device=device) * int(self.num_classes) + gt_keypoints_expand_i = torch.rand( + refine_queries_num, self.num_keypoints * 3, device=device) + gt_boxes_expand.append(gt_boxes_expand_i) + gt_labels_expand.append(gt_labels_expand_i) + gt_keypoints_expand.append(gt_keypoints_expand_i) + gt_boxes_expand = torch.stack(gt_boxes_expand) + gt_labels_expand = torch.stack(gt_labels_expand) + gt_keypoints_expand = torch.stack(gt_keypoints_expand) + knwon_boxes_expand = gt_boxes_expand.clone() + knwon_labels_expand = gt_labels_expand.clone() + + # add noise + if self.denosing_cfg['dn_label_noise_ratio'] > 0: + prob = torch.rand_like(knwon_labels_expand.float()) + chosen_indice = prob < self.denosing_cfg['dn_label_noise_ratio'] + new_label = torch.randint_like( + knwon_labels_expand[chosen_indice], 0, + self.dn_labelbook_size) # randomly put a new one here + knwon_labels_expand[chosen_indice] = new_label + + if self.denosing_cfg['dn_box_noise_scale'] > 0: + diff = torch.zeros_like(knwon_boxes_expand) + diff[..., :2] = knwon_boxes_expand[..., 2:] / 2 + diff[..., 2:] = knwon_boxes_expand[..., 2:] + knwon_boxes_expand += torch.mul( + (torch.rand_like(knwon_boxes_expand) * 2 - 1.0), + diff) * self.denosing_cfg['dn_box_noise_scale'] + knwon_boxes_expand = knwon_boxes_expand.clamp(min=0.0, max=1.0) + + input_query_label = self.label_enc(knwon_labels_expand) + input_query_bbox = inverse_sigmoid(knwon_boxes_expand) + + # prepare mask + if 'group2group' in self.denosing_cfg['dn_attn_mask_type_list']: + attn_mask = torch.zeros( + bs, + self.num_heads, + refine_queries_num + self.num_queries, + refine_queries_num + self.num_queries, + device=device, + dtype=torch.bool) + attn_mask[:, :, refine_queries_num:, :refine_queries_num] = True + for idx, (gt_boxes_i, + gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): + num_gt_i = gt_boxes_i.shape[0] + if num_gt_i == 0: + continue + for matchi in range(refine_queries_num): + si = (matchi // num_gt_i) * num_gt_i + ei = (matchi // num_gt_i + 1) * num_gt_i + if si > 0: + attn_mask[idx, :, matchi, :si] = True + if ei < refine_queries_num: + attn_mask[idx, :, matchi, ei:refine_queries_num] = True + attn_mask = attn_mask.flatten(0, 1) + + if 'group2group' in 
self.denosing_cfg['dn_attn_mask_type_list']: + attn_mask2 = torch.zeros( + bs, + self.num_heads, + refine_queries_num + self.num_group * (self.num_keypoints + 1), + refine_queries_num + self.num_group * (self.num_keypoints + 1), + device=device, + dtype=torch.bool) + attn_mask2[:, :, refine_queries_num:, :refine_queries_num] = True + group_bbox_kpt = (self.num_keypoints + 1) + kpt_index = [ + x for x in range(self.num_group * (self.num_keypoints + 1)) + if x % (self.num_keypoints + 1) == 0 + ] + for matchj in range(self.num_group * (self.num_keypoints + 1)): + sj = (matchj // group_bbox_kpt) * group_bbox_kpt + ej = (matchj // group_bbox_kpt + 1) * group_bbox_kpt + if sj > 0: + attn_mask2[:, :, refine_queries_num:, + refine_queries_num:][:, :, matchj, :sj] = True + if ej < self.num_group * (self.num_keypoints + 1): + attn_mask2[:, :, refine_queries_num:, + refine_queries_num:][:, :, matchj, ej:] = True + + for match_x in range(self.num_group * (self.num_keypoints + 1)): + if match_x % group_bbox_kpt == 0: + attn_mask2[:, :, refine_queries_num:, + refine_queries_num:][:, :, match_x, + kpt_index] = False + + for idx, (gt_boxes_i, + gt_labels_i) in enumerate(zip(gt_boxes, gt_labels)): + num_gt_i = gt_boxes_i.shape[0] + if num_gt_i == 0: + continue + for matchi in range(refine_queries_num): + si = (matchi // num_gt_i) * num_gt_i + ei = (matchi // num_gt_i + 1) * num_gt_i + if si > 0: + attn_mask2[idx, :, matchi, :si] = True + if ei < refine_queries_num: + attn_mask2[idx, :, matchi, + ei:refine_queries_num] = True + attn_mask2 = attn_mask2.flatten(0, 1) + + mask_dict = { + 'pad_size': refine_queries_num, + 'known_bboxs': gt_boxes_expand, + 'known_labels': gt_labels_expand, + 'known_keypoints': gt_keypoints_expand + } + + return input_query_label, input_query_bbox, \ + attn_mask, attn_mask2, mask_dict + + def loss(self, + feats: Tuple[Tensor], + batch_data_samples: OptSampleList, + train_cfg: OptConfigType = {}) -> dict: + """Calculate losses from a batch of inputs and data samples.""" + + raise NotImplementedError( + 'the training of EDPose is not supported yet. ' + 'Please stay tuned for further updates.') diff --git a/mmpose/models/heads/transformer_heads/transformers/__init__.py b/mmpose/models/heads/transformer_heads/transformers/__init__.py new file mode 100644 index 0000000000..0e9f115cd1 --- /dev/null +++ b/mmpose/models/heads/transformer_heads/transformers/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .deformable_detr_layers import (DeformableDetrTransformerDecoder, + DeformableDetrTransformerDecoderLayer, + DeformableDetrTransformerEncoder, + DeformableDetrTransformerEncoderLayer) +from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer) +from .utils import FFN, PositionEmbeddingSineHW + +__all__ = [ + 'DetrTransformerEncoder', 'DetrTransformerDecoder', + 'DetrTransformerEncoderLayer', 'DetrTransformerDecoderLayer', + 'DeformableDetrTransformerEncoder', 'DeformableDetrTransformerDecoder', + 'DeformableDetrTransformerEncoderLayer', + 'DeformableDetrTransformerDecoderLayer', 'PositionEmbeddingSineHW', 'FFN' +] diff --git a/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py b/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py new file mode 100644 index 0000000000..149f04e469 --- /dev/null +++ b/mmpose/models/heads/transformer_heads/transformers/deformable_detr_layers.py @@ -0,0 +1,251 @@ +# Copyright (c) OpenMMLab.
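To visualize the group2group masking built above: each (bbox + keypoints) group only attends within its own block, while the bbox tokens may additionally attend to each other across groups. Below is a tiny sketch that builds the same structure directly (opening blocks rather than masking their complements); the group and keypoint counts are toy assumptions, COCO uses 100 groups of 17 keypoints.

```python
import torch

num_group, num_keypoints = 2, 3   # toy sizes
group_size = num_keypoints + 1    # one bbox token plus its keypoint tokens
total = num_group * group_size

# True means "do not attend": start fully masked, open each group's block.
attn_mask = torch.ones(total, total, dtype=torch.bool)
for g in range(num_group):
    s, e = g * group_size, (g + 1) * group_size
    attn_mask[s:e, s:e] = False   # tokens attend within their own group

# bbox tokens (one per group) may also attend to every other bbox token
bbox_index = list(range(0, total, group_size))
for x in bbox_index:
    attn_mask[x, bbox_index] = False

print(attn_mask.int())  # block-diagonal with open bbox-to-bbox entries
```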
All rights reserved. +from typing import Optional, Tuple, Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import ModuleList +from torch import Tensor, nn + +from mmpose.models.utils import inverse_sigmoid +from .detr_layers import (DetrTransformerDecoder, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DetrTransformerEncoderLayer) + + +class DeformableDetrTransformerEncoder(DetrTransformerEncoder): + """Transformer encoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, spatial_shapes: Tensor, + level_start_index: Tensor, valid_ratios: Tensor, + **kwargs) -> Tensor: + """Forward function of Transformer encoder. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, has shape + (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + Tensor: Output queries of Transformer encoder, which is also + called 'encoder output embeddings' or 'memory', has shape + (bs, num_queries, dim) + """ + reference_points = self.get_encoder_reference_points( + spatial_shapes, valid_ratios, device=query.device) + for layer in self.layers: + query = layer( + query=query, + query_pos=query_pos, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points, + **kwargs) + return query + + @staticmethod + def get_encoder_reference_points(spatial_shapes: Tensor, + valid_ratios: Tensor, + device: Union[torch.device, + str]) -> Tensor: + """Get the reference points used in encoder. + + Args: + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + device (obj:`device` or str): The device acquired by the + `reference_points`. + + Returns: + Tensor: Reference points used in decoder, has shape (bs, length, + num_levels, 2). 
+ """ + + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W - 0.5, W, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + # [bs, sum(hw), num_level, 2] + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + +class DeformableDetrTransformerDecoder(DetrTransformerDecoder): + """Transformer Decoder of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DeformableDetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + if self.post_norm_cfg is not None: + raise ValueError('There is not post_norm in ' + f'{self._get_name()}') + + def forward(self, + query: Tensor, + query_pos: Tensor, + value: Tensor, + key_padding_mask: Tensor, + reference_points: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + reg_branches: Optional[nn.Module] = None, + **kwargs) -> Tuple[Tensor]: + """Forward function of Transformer decoder. + + Args: + query (Tensor): The input queries, has shape (bs, num_queries, + dim). + query_pos (Tensor): The input positional query, has shape + (bs, num_queries, dim). It will be added to `query` before + forward function. + value (Tensor): The input values, has shape (bs, num_value, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged + as (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`, optional): Used for refining + the regression results. Only would be passed when + `with_box_refine` is `True`, otherwise would be `None`. + + Returns: + tuple[Tensor]: Outputs of Deformable Transformer Decoder. + + - output (Tensor): Output embeddings of the last decoder, has + shape (num_queries, bs, embed_dims) when `return_intermediate` + is `False`. Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, num_queries, bs, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_queries, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (num_decoder_layers, bs, num_queries, 4). 
The + coordinates are arranged as (cx, cy, w, h) + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for layer_id, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + query_pos=query_pos, + value=value, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input, + **kwargs) + + if reg_branches is not None: + tmp_reg_preds = reg_branches[layer_id](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp_reg_preds + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp_reg_preds + new_reference_points[..., :2] = tmp_reg_preds[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +class DeformableDetrTransformerEncoderLayer(DetrTransformerEncoderLayer): + """Encoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, ffn, and norms.""" + self.self_attn = MultiScaleDeformableAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + +class DeformableDetrTransformerDecoderLayer(DetrTransformerDecoderLayer): + """Decoder layer of Deformable DETR.""" + + def _init_layers(self) -> None: + """Initialize self_attn, cross-attn, ffn, and norms.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiScaleDeformableAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) diff --git a/mmpose/models/heads/transformer_heads/transformers/detr_layers.py b/mmpose/models/heads/transformer_heads/transformers/detr_layers.py new file mode 100644 index 0000000000..a669c5dda6 --- /dev/null +++ b/mmpose/models/heads/transformer_heads/transformers/detr_layers.py @@ -0,0 +1,354 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmengine import ConfigDict +from mmengine.model import BaseModule, ModuleList +from torch import Tensor + +from mmpose.utils.typing import ConfigType, OptConfigType + + +class DetrTransformerEncoder(BaseModule): + """Encoder of DETR. + + Args: + num_layers (int): Number of encoder layers. + layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. 
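The iterative box refinement in the decoder forward above boils down to one recurring idiom: the regression branch predicts a delta in unsigmoided (logit) space, which is added to `inverse_sigmoid` of the current reference and squashed back to (0, 1). A minimal sketch of a single step, with a simplified `inverse_sigmoid` and made-up numbers:

```python
import torch


def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # simplified version of mmpose.models.utils.inverse_sigmoid
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))


reference_points = torch.tensor([[0.50, 0.50, 0.20, 0.20]])  # (cx, cy, w, h)
tmp_reg_preds = torch.tensor([[0.10, -0.10, 0.00, 0.05]])    # logit-space delta

# add the delta in unsigmoided space, squash back to (0, 1)
new_reference = (tmp_reg_preds + inverse_sigmoid(reference_points)).sigmoid()
print(new_reference)  # each coordinate nudged while staying in (0, 1)

# The decoder then detaches the result, so every layer's refinement is
# supervised independently rather than backpropagating through the chain.
new_reference = new_reference.detach()
```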
+ init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + self.num_layers = num_layers + self.layer_cfg = layer_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize encoder layers.""" + self.layers = ModuleList([ + DetrTransformerEncoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of encoder. + + Args: + query (Tensor): Input queries of encoder, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional embeddings of the queries, has + shape (bs, num_queries, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor, has shape (bs, num_queries). + + Returns: + Tensor: Has shape (bs, num_queries, dim) if `batch_first` is + `True`, otherwise (num_queries, bs, dim). + """ + for layer in self.layers: + query = layer(query, query_pos, key_padding_mask, **kwargs) + return query + + +class DetrTransformerDecoder(BaseModule): + """Decoder of DETR. + + Args: + num_layers (int): Number of decoder layers. + layer_cfg (:obj:`ConfigDict` or dict): the config of each encoder + layer. All the layers will share the same config. + post_norm_cfg (:obj:`ConfigDict` or dict, optional): Config of the + post normalization layer. Defaults to `LN`. + return_intermediate (bool, optional): Whether to return outputs of + intermediate layers. Defaults to `True`, + init_cfg (:obj:`ConfigDict` or dict, optional): the config to control + the initialization. Defaults to None. + """ + + def __init__(self, + num_layers: int, + layer_cfg: ConfigType, + post_norm_cfg: OptConfigType = dict(type='LN'), + return_intermediate: bool = True, + init_cfg: Union[dict, ConfigDict] = None) -> None: + super().__init__(init_cfg=init_cfg) + self.layer_cfg = layer_cfg + self.num_layers = num_layers + self.post_norm_cfg = post_norm_cfg + self.return_intermediate = return_intermediate + self._init_layers() + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + DetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + self.post_norm = build_norm_layer(self.post_norm_cfg, + self.embed_dims)[1] + + def forward(self, query: Tensor, key: Tensor, value: Tensor, + query_pos: Tensor, key_pos: Tensor, key_padding_mask: Tensor, + **kwargs) -> Tensor: + """Forward function of decoder + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor): The input key, has shape (bs, num_keys, dim). + value (Tensor): The input value with the same shape as `key`. + query_pos (Tensor): The positional encoding for `query`, with the + same shape as `query`. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + + Returns: + Tensor: The forwarded results will have shape + (num_decoder_layers, bs, num_queries, dim) if + `return_intermediate` is `True` else (1, bs, num_queries, dim). 
+ """ + intermediate = [] + for layer in self.layers: + query = layer( + query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + key_padding_mask=key_padding_mask, + **kwargs) + if self.return_intermediate: + intermediate.append(self.post_norm(query)) + query = self.post_norm(query) + + if self.return_intermediate: + return torch.stack(intermediate) + + return query.unsqueeze(0) + + +class DetrTransformerEncoderLayer(BaseModule): + """Implements encoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. + """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, num_heads=8, dropout=0.0), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True)), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + + self.self_attn_cfg = self_attn_cfg + if 'batch_first' not in self.self_attn_cfg: + self.self_attn_cfg['batch_first'] = True + else: + assert self.self_attn_cfg['batch_first'] is True, 'First \ + dimension of all DETRs in mmdet is `batch`, \ + please set `batch_first` flag.' + + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(2) + ] + self.norms = ModuleList(norms_list) + + def forward(self, query: Tensor, query_pos: Tensor, + key_padding_mask: Tensor, **kwargs) -> Tensor: + """Forward function of an encoder layer. + + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + query_pos (Tensor): The positional encoding for query, with + the same shape as `query`. + key_padding_mask (Tensor): The `key_padding_mask` of `self_attn` + input. ByteTensor. has shape (bs, num_queries). + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). + """ + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[0](query) + query = self.ffn(query) + query = self.norms[1](query) + + return query + + +class DetrTransformerDecoderLayer(BaseModule): + """Implements decoder layer in DETR transformer. + + Args: + self_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for self + attention. + cross_attn_cfg (:obj:`ConfigDict` or dict, optional): Config for cross + attention. + ffn_cfg (:obj:`ConfigDict` or dict, optional): Config for FFN. + norm_cfg (:obj:`ConfigDict` or dict, optional): Config for + normalization layers. All the layers will share the same + config. Defaults to `LN`. + init_cfg (:obj:`ConfigDict` or dict, optional): Config to control + the initialization. Defaults to None. 
+ """ + + def __init__(self, + self_attn_cfg: OptConfigType = dict( + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + cross_attn_cfg: OptConfigType = dict( + embed_dims=256, + num_heads=8, + dropout=0.0, + batch_first=True), + ffn_cfg: OptConfigType = dict( + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True), + ), + norm_cfg: OptConfigType = dict(type='LN'), + init_cfg: OptConfigType = None) -> None: + + super().__init__(init_cfg=init_cfg) + + self.self_attn_cfg = self_attn_cfg + self.cross_attn_cfg = cross_attn_cfg + if 'batch_first' not in self.self_attn_cfg: + self.self_attn_cfg['batch_first'] = True + else: + assert self.self_attn_cfg['batch_first'] is True, 'First \ + dimension of all DETRs in mmdet is `batch`, \ + please set `batch_first` flag.' + + if 'batch_first' not in self.cross_attn_cfg: + self.cross_attn_cfg['batch_first'] = True + else: + assert self.cross_attn_cfg['batch_first'] is True, 'First \ + dimension of all DETRs in mmdet is `batch`, \ + please set `batch_first` flag.' + + self.ffn_cfg = ffn_cfg + self.norm_cfg = norm_cfg + self._init_layers() + + def _init_layers(self) -> None: + """Initialize self-attention, FFN, and normalization.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiheadAttention(**self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) + + def forward(self, + query: Tensor, + key: Tensor = None, + value: Tensor = None, + query_pos: Tensor = None, + key_pos: Tensor = None, + self_attn_mask: Tensor = None, + cross_attn_mask: Tensor = None, + key_padding_mask: Tensor = None, + **kwargs) -> Tensor: + """ + Args: + query (Tensor): The input query, has shape (bs, num_queries, dim). + key (Tensor, optional): The input key, has shape (bs, num_keys, + dim). If `None`, the `query` will be used. Defaults to `None`. + value (Tensor, optional): The input value, has the same shape as + `key`, as in `nn.MultiheadAttention.forward`. If `None`, the + `key` will be used. Defaults to `None`. + query_pos (Tensor, optional): The positional encoding for `query`, + has the same shape as `query`. If not `None`, it will be added + to `query` before forward function. Defaults to `None`. + key_pos (Tensor, optional): The positional encoding for `key`, has + the same shape as `key`. If not `None`, it will be added to + `key` before forward function. If None, and `query_pos` has the + same shape as `key`, then `query_pos` will be used for + `key_pos`. Defaults to None. + self_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + cross_attn_mask (Tensor, optional): ByteTensor mask, has shape + (num_queries, num_keys), as in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor, optional): The `key_padding_mask` of + `self_attn` input. ByteTensor, has shape (bs, num_value). + Defaults to None. + + Returns: + Tensor: forwarded results, has shape (bs, num_queries, dim). 
+ """ + + query = self.self_attn( + query=query, + key=query, + value=query, + query_pos=query_pos, + key_pos=query_pos, + attn_mask=self_attn_mask, + **kwargs) + query = self.norms[0](query) + query = self.cross_attn( + query=query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_mask=cross_attn_mask, + key_padding_mask=key_padding_mask, + **kwargs) + query = self.norms[1](query) + query = self.ffn(query) + query = self.norms[2](query) + + return query diff --git a/mmpose/models/heads/transformer_heads/transformers/utils.py b/mmpose/models/heads/transformer_heads/transformers/utils.py new file mode 100644 index 0000000000..7d7c086dc8 --- /dev/null +++ b/mmpose/models/heads/transformer_heads/transformers/utils.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn.functional as F +from mmcv.cnn import Linear +from mmengine.model import BaseModule, ModuleList +from torch import Tensor + + +class FFN(BaseModule): + """Very simple multi-layer perceptron with relu. Mostly used in DETR series + detectors. + + Args: + input_dim (int): Feature dim of the input tensor. + hidden_dim (int): Feature dim of the hidden layer. + output_dim (int): Feature dim of the output tensor. + num_layers (int): Number of FFN layers.. + """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, + num_layers: int) -> None: + super().__init__() + + self.num_layers = num_layers + + self.layers = ModuleList() + self.layers.append(Linear(input_dim, hidden_dim)) + for _ in range(num_layers - 2): + self.layers.append(Linear(hidden_dim, hidden_dim)) + self.layers.append(Linear(hidden_dim, output_dim)) + + def forward(self, x: Tensor) -> Tensor: + """Forward function of FFN. + + Args: + x (Tensor): The input feature, has shape + (num_queries, bs, input_dim). + Returns: + Tensor: The output feature, has shape + (num_queries, bs, output_dim). 
+ """ + for i, layer in enumerate(self.layers): + x = layer(x) + if i < self.num_layers - 1: + x = F.relu(x) + return x + + +class PositionEmbeddingSineHW(BaseModule): + """This is a more standard version of the position embedding, very similar + to the one used by the Attention is all you need paper, generalized to work + on images.""" + + def __init__(self, + num_pos_feats=64, + temperatureH=10000, + temperatureW=10000, + normalize=False, + scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperatureH = temperatureH + self.temperatureW = temperatureW + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError('normalize should be True if scale is passed') + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, mask: Tensor): + + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_tx = torch.arange( + self.num_pos_feats, dtype=torch.float32, device=mask.device) + dim_tx = self.temperatureW**(2 * (dim_tx // 2) / self.num_pos_feats) + pos_x = x_embed[:, :, :, None] / dim_tx + + dim_ty = torch.arange( + self.num_pos_feats, dtype=torch.float32, device=mask.device) + dim_ty = self.temperatureH**(2 * (dim_ty // 2) / self.num_pos_feats) + pos_y = y_embed[:, :, :, None] / dim_ty + + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + + return pos diff --git a/mmpose/models/necks/channel_mapper.py b/mmpose/models/necks/channel_mapper.py index 246ed363d8..4d4148a089 100644 --- a/mmpose/models/necks/channel_mapper.py +++ b/mmpose/models/necks/channel_mapper.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+        """
+        assert mask is not None
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_tx = torch.arange(
+            self.num_pos_feats, dtype=torch.float32, device=mask.device)
+        dim_tx = self.temperatureW**(2 * (dim_tx // 2) / self.num_pos_feats)
+        pos_x = x_embed[:, :, :, None] / dim_tx
+
+        dim_ty = torch.arange(
+            self.num_pos_feats, dtype=torch.float32, device=mask.device)
+        dim_ty = self.temperatureH**(2 * (dim_ty // 2) / self.num_pos_feats)
+        pos_y = y_embed[:, :, :, None] / dim_ty
+
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+            dim=4).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+            dim=4).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+
+        return pos
diff --git a/mmpose/models/necks/channel_mapper.py b/mmpose/models/necks/channel_mapper.py
index 246ed363d8..4d4148a089 100644
--- a/mmpose/models/necks/channel_mapper.py
+++ b/mmpose/models/necks/channel_mapper.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import torch.nn as nn
 from mmcv.cnn import ConvModule
@@ -56,6 +56,7 @@ def __init__(
         norm_cfg: OptConfigType = None,
         act_cfg: OptConfigType = dict(type='ReLU'),
         num_outs: int = None,
+        bias: Union[bool, str] = 'auto',
         init_cfg: OptMultiConfig = dict(
             type='Xavier', layer='Conv2d', distribution='uniform')
 ) -> None:
@@ -71,6 +72,7 @@ def __init__(
                     in_channel,
                     out_channels,
                     kernel_size,
+                    bias=bias,
                     padding=(kernel_size - 1) // 2,
                     conv_cfg=conv_cfg,
                     norm_cfg=norm_cfg,
@@ -89,6 +91,7 @@ def __init__(
                         3,
                         stride=2,
                         padding=1,
+                        bias=bias,
                         conv_cfg=conv_cfg,
                         norm_cfg=norm_cfg,
                         act_cfg=act_cfg))
diff --git a/mmpose/models/utils/__init__.py b/mmpose/models/utils/__init__.py
index 545fc4c64d..539da6ea2f 100644
--- a/mmpose/models/utils/__init__.py
+++ b/mmpose/models/utils/__init__.py
@@ -3,10 +3,12 @@
 from .ckpt_convert import pvt_convert
 from .csp_layer import CSPLayer
 from .misc import filter_scores_and_topk
+from .ops import FrozenBatchNorm2d, inverse_sigmoid
 from .rtmcc_block import RTMCCBlock, rope
 from .transformer import PatchEmbed, nchw_to_nlc, nlc_to_nchw
 
 __all__ = [
     'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw', 'pvt_convert', 'RTMCCBlock',
-    'rope', 'check_and_update_config', 'filter_scores_and_topk', 'CSPLayer'
+    'rope', 'check_and_update_config', 'filter_scores_and_topk', 'CSPLayer',
+    'FrozenBatchNorm2d', 'inverse_sigmoid'
 ]
diff --git a/mmpose/models/utils/ops.py b/mmpose/models/utils/ops.py
index 0c94352647..d1ba0cf37c 100644
--- a/mmpose/models/utils/ops.py
+++ b/mmpose/models/utils/ops.py
@@ -3,8 +3,11 @@
 from typing import Optional, Tuple, Union
 
 import torch
+from torch import Tensor
 from torch.nn import functional as F
 
+from mmpose.registry import MODELS
+
 
 def resize(input: torch.Tensor,
            size: Optional[Union[Tuple[int, int], torch.Size]] = None,
@@ -50,3 +53,58 @@ def resize(input: torch.Tensor,
 
     # Perform the resizing operation
     return F.interpolate(input, size, scale_factor, mode, align_corners)
+
+
+@MODELS.register_module()
+class FrozenBatchNorm2d(torch.nn.Module):
+    """BatchNorm2d where the batch statistics and the affine parameters are
+    fixed.
+
+    Copy-paste from ``torchvision.ops.misc`` with an added ``eps`` before
+    the ``rsqrt``, without which models other than
+    ``torchvision.models.resnet[18,34,50,101]`` produce NaNs.
+
+    Args:
+        n (int): Number of input channels.
+        eps (float): A small value added to the running variance before
+            ``rsqrt``. Defaults to 1e-5.
+    """
+
+    def __init__(self, n, eps: float = 1e-5):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.register_buffer('weight', torch.ones(n))
+        self.register_buffer('bias', torch.zeros(n))
+        self.register_buffer('running_mean', torch.zeros(n))
+        self.register_buffer('running_var', torch.ones(n))
+        self.eps = eps
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                              strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        num_batches_tracked_key = prefix + 'num_batches_tracked'
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super(FrozenBatchNorm2d,
+              self)._load_from_state_dict(state_dict, prefix, local_metadata,
+                                          strict, missing_keys,
+                                          unexpected_keys, error_msgs)
+
+    def forward(self, x):
+        # Reshape the frozen statistics once so the normalization below is
+        # a single multiply-add over the channel dimension.
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        scale = w * (rv + self.eps).rsqrt()
+        bias = b - rm * scale
+        return x * scale + bias
+
+
+def inverse_sigmoid(x: Tensor, eps: float = 1e-3) -> Tensor:
+    """Inverse function of sigmoid.
+
+    Args:
+        x (Tensor): The tensor to do the inverse.
+        eps (float): A small value to clamp the input, avoiding numerical
+            overflow. Defaults to 1e-3.
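+
+    Example::
+
+        >>> # Illustrative round trip: away from the clamping bounds,
+        >>> # ``torch.sigmoid`` undoes this function.
+        >>> import torch
+        >>> x = torch.tensor([0.1, 0.5, 0.9])
+        >>> torch.sigmoid(inverse_sigmoid(x))
+        tensor([0.1000, 0.5000, 0.9000])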
+
+    Returns:
+        Tensor: The inverse sigmoid of the input tensor, with the same
+        shape as the input.
+    """
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1 / x2)
diff --git a/mmpose/visualization/local_visualizer_3d.py b/mmpose/visualization/local_visualizer_3d.py
index 316fd3147c..63a1d5b47c 100644
--- a/mmpose/visualization/local_visualizer_3d.py
+++ b/mmpose/visualization/local_visualizer_3d.py
@@ -197,7 +197,7 @@ def _draw_3d_instances_kpts(keypoints,
 
                 kpt_color = kpt_color[valid][..., ::-1] / 255.
 
-                ax.scatter(x_3d, y_3d, z_3d, marker='o', color=kpt_color)
+                ax.scatter(x_3d, y_3d, z_3d, marker='o', c=kpt_color)
 
                 for kpt_idx in range(len(x_3d)):
                     ax.text(x_3d[kpt_idx][0], y_3d[kpt_idx][0],
diff --git a/model-index.yml b/model-index.yml
index 8dc3f25054..1837b97eae 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -12,6 +12,7 @@ Import:
 - configs/body_2d_keypoint/rtmpose/body8/rtmpose_body8-coco.yml
 - configs/body_2d_keypoint/rtmpose/body8/rtmpose_body8-halpe26.yml
 - configs/body_2d_keypoint/dekr/crowdpose/hrnet_crowdpose.yml
+- configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
 - configs/body_2d_keypoint/integral_regression/coco/resnet_ipr_coco.yml
 - configs/body_2d_keypoint/integral_regression/coco/resnet_dsnt_coco.yml
 - configs/body_2d_keypoint/integral_regression/coco/resnet_debias_coco.yml
diff --git a/tests/test_codecs/test_edpose_label.py b/tests/test_codecs/test_edpose_label.py
new file mode 100644
index 0000000000..79e4d3fe27
--- /dev/null
+++ b/tests/test_codecs/test_edpose_label.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from unittest import TestCase
+
+import numpy as np
+
+from mmpose.codecs import EDPoseLabel
+
+
+class TestEDPoseLabel(TestCase):
+
+    def setUp(self):
+        self.encoder = EDPoseLabel(num_select=2, num_keypoints=2)
+        self.img_shape = (640, 480)
+        self.keypoints = np.array([[[100, 50], [200, 50]],
+                                   [[300, 400], [100, 200]]])
+        self.area = np.array([5000, 8000])
+
+    def test_encode(self):
+        # Test encoding
+        encoded_data = self.encoder.encode(
+            img_shape=self.img_shape,
+            keypoints=self.keypoints,
+            area=self.area)
+
+        self.assertEqual(encoded_data['keypoints'].shape,
+                         self.keypoints.shape)
+        self.assertEqual(encoded_data['area'].shape, self.area.shape)
+
+        # Check if the keypoints were normalized correctly
+        expected_keypoints = self.keypoints / np.array(
+            self.img_shape, dtype=np.float32)
+        np.testing.assert_array_almost_equal(encoded_data['keypoints'],
+                                             expected_keypoints)
+
+        # Check if the area was normalized correctly
+        expected_area = self.area / float(
+            self.img_shape[0] * self.img_shape[1])
+        np.testing.assert_array_almost_equal(encoded_data['area'],
+                                             expected_area)
+
+    def test_decode(self):
+        # Dummy predictions for logits, boxes and keypoints
+        pred_logits = np.array([0.7, 0.6]).reshape(2, 1)
+        pred_boxes = np.array([[0.1, 0.1, 0.5, 0.5], [0.6, 0.6, 0.8, 0.8]])
+        pred_keypoints = np.array([[0.2, 0.3, 1, 0.3, 0.4, 1],
+                                   [0.6, 0.7, 1, 0.7, 0.8, 1]])
+        input_shapes = np.array(self.img_shape)
+
+        # Test decoding
+        boxes, keypoints, scores = self.encoder.decode(
+            input_shapes=input_shapes,
+            pred_logits=pred_logits,
+            pred_boxes=pred_boxes,
+            pred_keypoints=pred_keypoints)
+
+        self.assertEqual(boxes.shape, pred_boxes.shape)
+        self.assertEqual(
+            keypoints.shape,
+            (self.encoder.num_select, self.encoder.num_keypoints, 2))
+        self.assertEqual(
+            scores.shape,
+            (self.encoder.num_select, self.encoder.num_keypoints))
diff --git a/tests/test_datasets/test_transforms/test_bottomup_transforms.py b/tests/test_datasets/test_transforms/test_bottomup_transforms.py
index cded7a6efb..8d9213c729 100644
--- a/tests/test_datasets/test_transforms/test_bottomup_transforms.py
+++ b/tests/test_datasets/test_transforms/test_bottomup_transforms.py
@@ -6,7 +6,9 @@
 from mmcv.transforms import Compose
 
 from mmpose.datasets.transforms import (BottomupGetHeatmapMask,
-                                        BottomupRandomAffine, BottomupResize,
+                                        BottomupRandomAffine,
+                                        BottomupRandomChoiceResize,
+                                        BottomupRandomCrop, BottomupResize,
                                         RandomFlip)
 from mmpose.testing import get_coco_sample
 
@@ -145,3 +147,166 @@ def test_transform(self):
         self.assertIsInstance(results['input_scale'], np.ndarray)
         self.assertEqual(results['img'][0].shape, (256, 256, 3))
         self.assertEqual(results['img'][1].shape, (384, 384, 3))
+
+
+class TestBottomupRandomCrop(TestCase):
+
+    def setUp(self):
+        # test invalid crop_type
+        with self.assertRaisesRegex(ValueError, 'Invalid crop_type'):
+            BottomupRandomCrop(crop_size=(10, 10), crop_type='unknown')
+
+        crop_type_list = ['absolute', 'absolute_range']
+        for crop_type in crop_type_list:
+            # test h > 0 and w > 0
+            for crop_size in [(0, 0), (0, 1), (1, 0)]:
+                with self.assertRaises(AssertionError):
+                    BottomupRandomCrop(
+                        crop_size=crop_size, crop_type=crop_type)
+            # test type(h) = int and type(w) = int
+            for crop_size in [(1.0, 1), (1, 1.0), (1.0, 1.0)]:
+                with self.assertRaises(AssertionError):
+                    BottomupRandomCrop(
+                        crop_size=crop_size, crop_type=crop_type)
+
+        # test crop_size[0] <= crop_size[1]
+        with self.assertRaises(AssertionError):
+            BottomupRandomCrop(crop_size=(10, 5), crop_type='absolute_range')
+
+        # test h in (0, 1] and w in (0, 1]
+        crop_type_list = ['relative_range', 'relative']
+        for crop_type in crop_type_list:
+            for crop_size in [(0, 1), (1, 0), (1.1, 0.5), (0.5, 1.1)]:
+                with self.assertRaises(AssertionError):
+                    BottomupRandomCrop(
+                        crop_size=crop_size, crop_type=crop_type)
+
+        self.data_info = get_coco_sample(img_shape=(24, 32))
+
+    def test_transform(self):
+        # test relative and absolute crop
+        src_results = self.data_info
+        target_shape = (12, 16)
+        for crop_type, crop_size in zip(['relative', 'absolute'],
+                                        [(0.5, 0.5), (16, 12)]):
+            transform = BottomupRandomCrop(
+                crop_size=crop_size, crop_type=crop_type)
+            results = transform(deepcopy(src_results))
+            self.assertEqual(results['img'].shape[:2], target_shape)
+
+        # test absolute_range crop
+        transform = BottomupRandomCrop(
+            crop_size=(10, 20), crop_type='absolute_range')
+        results = transform(deepcopy(src_results))
+        h, w = results['img'].shape[:2]
+        self.assertTrue(10 <= w <= 20)
+        self.assertTrue(10 <= h <= 20)
+        self.assertEqual(results['img_shape'], results['img'].shape[:2])
+
+        # test relative_range crop
+        transform = BottomupRandomCrop(
+            crop_size=(0.5, 0.5), crop_type='relative_range')
+        results = transform(deepcopy(src_results))
+        h, w = results['img'].shape[:2]
+        self.assertTrue(16 <= w <= 32)
+        self.assertTrue(12 <= h <= 24)
+        self.assertEqual(results['img_shape'], results['img'].shape[:2])
+
+        # test with keypoints, bbox, segmentation
+        src_results = get_coco_sample(img_shape=(10, 10), num_instances=2)
+        segmentation = np.random.randint(
+            0, 255, size=(10, 10), dtype=np.uint8)
+        keypoints = np.ones_like(src_results['keypoints']) * 5
+        src_results['segmentation'] = segmentation
+        src_results['keypoints'] = keypoints
+        transform = BottomupRandomCrop(
+            crop_size=(7, 5),
+            allow_negative_crop=False,
+            recompute_bbox=False,
+            bbox_clip_border=True)
+        results = transform(deepcopy(src_results))
+        h, w = results['img'].shape[:2]
+        self.assertEqual(h, 5)
+        self.assertEqual(w, 7)
+        self.assertEqual(results['bbox'].shape[0], 2)
+        self.assertTrue(results['keypoints_visible'].all())
+        self.assertTupleEqual(results['segmentation'].shape[:2], (5, 7))
+        self.assertEqual(results['img_shape'], results['img'].shape[:2])
+
+        # test bbox_clip_border = False
+        transform = BottomupRandomCrop(
+            crop_size=(10, 11),
+            allow_negative_crop=False,
+            recompute_bbox=True,
+            bbox_clip_border=False)
+        results = transform(deepcopy(src_results))
+        self.assertTrue((results['bbox'] == src_results['bbox']).all())
+
+        # test the crop does not contain any gt-bbox
+        # allow_negative_crop = False
+        img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8)
+        bbox = np.zeros((0, 4), dtype=np.float32)
+        src_results = {'img': img, 'bbox': bbox}
+        transform = BottomupRandomCrop(
+            crop_size=(5, 3), allow_negative_crop=False)
+        results = transform(deepcopy(src_results))
+        self.assertIsNone(results)
+
+        # allow_negative_crop = True
+        img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8)
+        bbox = np.zeros((0, 4), dtype=np.float32)
+        src_results = {'img': img, 'bbox': bbox}
+        transform = BottomupRandomCrop(
+            crop_size=(5, 3), allow_negative_crop=True)
+        results = transform(deepcopy(src_results))
+        self.assertIsInstance(results, dict)
+
+
+class TestBottomupRandomChoiceResize(TestCase):
+
+    def setUp(self):
+        self.data_info = get_coco_sample(img_shape=(300, 400))
+
+    def test_transform(self):
+        results = dict()
+        # test with one scale
+        transform = BottomupRandomChoiceResize(scales=[(1333, 800)])
+        results = deepcopy(self.data_info)
+        results = transform(results)
+        self.assertEqual(results['img'].shape, (800, 1333, 3))
+
+        # test with multiple scales
+        _scale_choice = [(1333, 800), (1333, 600)]
+        transform = BottomupRandomChoiceResize(scales=_scale_choice)
+        results = deepcopy(self.data_info)
+        results = transform(results)
+        self.assertIn((results['img'].shape[1], results['img'].shape[0]),
+                      _scale_choice)
+
+        # test keep_ratio
+        transform = BottomupRandomChoiceResize(
+            scales=[(900, 600)], resize_type='Resize', keep_ratio=True)
+        results = deepcopy(self.data_info)
+        _input_ratio = results['img'].shape[0] / results['img'].shape[1]
+        results = transform(results)
+        _output_ratio = results['img'].shape[0] / results['img'].shape[1]
+        self.assertLess(abs(_input_ratio - _output_ratio), 1.5 * 1e-3)
+
+        # test clip_object_border
+        bbox = [[200, 150, 600, 450]]
+        transform = BottomupRandomChoiceResize(
+            scales=[(200, 150)], resize_type='Resize',
+            clip_object_border=True)
+        results = deepcopy(self.data_info)
+        results['bbox'] = np.array(bbox)
+        results = transform(results)
+        self.assertEqual(results['img'].shape, (150, 200, 3))
+        self.assertTrue((results['bbox'] == np.array([[100, 75, 200,
+                                                       150]])).all())
+
+        transform = BottomupRandomChoiceResize(
+            scales=[(200, 150)],
+            resize_type='Resize',
+            clip_object_border=False)
+        results = deepcopy(self.data_info)
+        results['bbox'] = np.array(bbox)
+        results = transform(results)
+        self.assertEqual(results['img'].shape, (150, 200, 3))
+        self.assertTrue(
+            np.equal(results['bbox'], np.array([[100, 75, 300, 225]])).all())