diff --git a/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md b/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md
index 4ce6da38c6..dbe14267ed 100644
--- a/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md
+++ b/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md
@@ -60,3 +60,9 @@ Results on COCO val2017 with detector having human AP of 56.4 on COCO val2017 da
 | [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb32-210e_coco-384x288.py) | 384x288 | 0.749 | 0.906 | 0.817 | 0.799 | 0.941 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb64-210e_coco-256x192-065d3625_20220926.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb64-210e_coco-256x192_20220926.log) |
 | [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192.py) | 256x192 | 0.736 | 0.904 | 0.818 | 0.791 | 0.942 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192-0345f330_20220928.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192_20220928.log) |
 | [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288.py) | 384x288 | 0.750 | 0.908 | 0.821 | 0.800 | 0.942 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288-7fbb906f_20220927.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288_20220927.log) |
+
+The following model is equipped with a visibility prediction head and has been trained on the combined COCO and AIC datasets.
+
+| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log |
+| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
+| [pose_resnet_50](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py) | 256x192 | 0.729 | 0.900 | 0.807 | 0.783 | 0.938 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge-21815b2c_20230726.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192_20220923.log) |
diff --git a/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py b/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py
new file mode 100644
index 0000000000..f5def39ed9
--- /dev/null
+++ b/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py
@@ -0,0 +1,167 @@
+_base_ = ['../../../_base_/default_runtime.py']
+
+# runtime
+train_cfg = dict(max_epochs=210, val_interval=10)
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+    type='Adam',
+    lr=5e-4,
+))
+
+# learning policy
+param_scheduler = [
+    dict(
+        type='LinearLR', begin=0, end=500, start_factor=0.001,
+        by_epoch=False),  # warm-up
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=210,
+        milestones=[170, 200],
+        gamma=0.1,
+        by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=512)
+
+# hooks
+default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+    type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)
+
+# model settings
+model = dict(
+    type='TopdownPoseEstimator',
+    data_preprocessor=dict(
+        type='PoseDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True),
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
+    ),
+    head=dict(
+        type='VisPredictHead',
+        loss=dict(
+            type='BCELoss',
+            use_target_weight=True,
+            use_sigmoid=True,
+            loss_weight=1e-3,
+        ),
+        pose_cfg=dict(
+            type='HeatmapHead',
+            in_channels=2048,
+            out_channels=17,
+            loss=dict(type='KeypointMSELoss', use_target_weight=True),
+            decoder=codec)),
+    test_cfg=dict(
+        flip_test=True,
+        flip_mode='heatmap',
+        shift_heatmap=True,
+    ))
+
+# base dataset settings
+dataset_type = 'CocoDataset'
+data_mode = 'topdown'
+data_root = 'data/coco/'
+
+# pipelines
+train_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='RandomFlip', direction='horizontal'),
+    dict(type='RandomHalfBody'),
+    dict(type='RandomBBoxTransform'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(type='PackPoseInputs')
+]
+val_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='PackPoseInputs')
+]
+
+# train datasets
+dataset_coco = dict(
+    type=dataset_type,
+    data_root=data_root,
+    data_mode=data_mode,
+    ann_file='annotations/person_keypoints_train2017.json',
+    data_prefix=dict(img='train2017/'),
+    pipeline=[],
+)
+
+dataset_aic = dict(
+    type='AicDataset',
+    data_root='data/aic/',
+    data_mode=data_mode,
+    ann_file='annotations/aic_train.json',
+    data_prefix=dict(img='ai_challenger_keypoint_train_20170902/'
+                     'keypoint_train_images_20170902/'),
+    pipeline=[
+        dict(
+            type='KeypointConverter',
+            num_keypoints=17,
+            mapping=[
+                (0, 6),
+                (1, 8),
+                (2, 10),
+                (3, 5),
+                (4, 7),
+                (5, 9),
+                (6, 12),
+                (7, 14),
+                (8, 16),
+                (9, 11),
+                (10, 13),
+                (11, 15),
+            ])
+    ],
+)
+
+# data loaders
+train_dataloader = dict(
+    batch_size=64,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type='CombinedDataset',
+        metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
+        datasets=[dataset_coco, dataset_aic],
+        pipeline=train_pipeline,
+        test_mode=False,
+    ))
+val_dataloader = dict(
+    batch_size=32,
+    num_workers=2,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/person_keypoints_val2017.json',
+        bbox_file='data/coco/person_detection_results/'
+        'COCO_val2017_detections_AP_H_56_person.json',
+        data_prefix=dict(img='val2017/'),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = dict(
+    type='CocoMetric',
+    # score_mode='bbox',
+    ann_file=data_root + 'annotations/person_keypoints_val2017.json')
+test_evaluator = val_evaluator
diff --git a/docs/en/advanced_guides/implement_new_models.md b/docs/en/advanced_guides/implement_new_models.md
index da46a99e39..ff54e2c5ff 100644
--- a/docs/en/advanced_guides/implement_new_models.md
+++ b/docs/en/advanced_guides/implement_new_models.md
@@ -79,3 +79,27 @@ class YourNewHead(BaseHead):
 ```
 
 Finally, please remember to import your new prediction head in [`__init__.py`](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/__init__.py).
+
+### Head with Keypoints Visibility Prediction
+
+Many models infer keypoint visibility from the confidence of their coordinate predictions, which is suboptimal. Our [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) wrapper instead lets a head predict keypoint visibility directly, supervised by the ground-truth visibility labels in the training data, which makes its predictions more reliable. To add visibility prediction, wrap your head module with `VisPredictHead` in the config file:
+
+```python
+model=dict(
+    ...
+    head=dict(
+        type='VisPredictHead',
+        loss=dict(
+            type='BCELoss',
+            use_target_weight=True,
+            use_sigmoid=True,
+            loss_weight=1e-3),
+        pose_cfg=dict(
+            type='HeatmapHead',
+            in_channels=2048,
+            out_channels=17,
+            loss=dict(type='KeypointMSELoss', use_target_weight=True),
+            decoder=codec)),
+    ...
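+    # note: `codec` refers to the codec settings dict defined earlier in
+    # the same config file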
+)
+```
diff --git a/docs/zh_cn/advanced_guides/implement_new_models.md b/docs/zh_cn/advanced_guides/implement_new_models.md
index d3ed96bd37..22e866b52b 100644
--- a/docs/zh_cn/advanced_guides/implement_new_models.md
+++ b/docs/zh_cn/advanced_guides/implement_new_models.md
@@ -78,3 +78,27 @@ class YourNewHead(BaseHead):
 ```
 
 Finally, please remember to import your new prediction head in [heads/\_\_init\_\_.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/__init__.py).
+
+### Head with Keypoints Visibility Prediction
+
+Many models judge keypoint visibility by the confidence of their coordinate predictions. However, this solution is suboptimal. We provide a head wrapper called [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) that enables a head module to predict keypoint visibility directly. Because the wrapper is trained with the ground-truth visibility labels in the training data, its predictions are more reliable. Users can add this wrapper to their own head module by modifying the config file. Here is an example:
+
+```python
+model=dict(
+    ...
+    head=dict(
+        type='VisPredictHead',
+        loss=dict(
+            type='BCELoss',
+            use_target_weight=True,
+            use_sigmoid=True,
+            loss_weight=1e-3),
+        pose_cfg=dict(
+            type='HeatmapHead',
+            in_channels=2048,
+            out_channels=17,
+            loss=dict(type='KeypointMSELoss', use_target_weight=True),
+            decoder=codec)),
+    ...
+)
+```
diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py
index 87068246f8..8f7aa05425 100644
--- a/mmpose/datasets/transforms/common_transforms.py
+++ b/mmpose/datasets/transforms/common_transforms.py
@@ -340,7 +340,7 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,
 
         Args:
             keypoints_visible (np.ndarray, optional): The visibility of
-                keypoints in shape (N, K, 1).
+                keypoints in shape (N, K, 1) or (N, K, 2).
             upper_body_ids (list): The list of upper body keypoint indices
             lower_body_ids (list): The list of lower body keypoint indices
 
@@ -349,6 +349,9 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,
             of each instance. ``None`` means not applying half-body transform.
         """
+        if keypoints_visible.ndim == 3:
+            keypoints_visible = keypoints_visible[..., 0]
+
         half_body_ids = []
 
         for visible in keypoints_visible:
@@ -390,7 +393,6 @@ def transform(self, results: Dict) -> Optional[dict]:
         Returns:
             dict: The result dict.
         """
-
         half_body_ids = self._random_select_half_body(
             keypoints_visible=results['keypoints_visible'],
             upper_body_ids=results['upper_body_ids'],
@@ -952,6 +954,10 @@ def transform(self, results: Dict) -> Optional[dict]:
                              ' \'keypoints\' in the results.')
         keypoints_visible = results['keypoints_visible']
 
+        if keypoints_visible.ndim == 3 and keypoints_visible.shape[2] == 2:
+            keypoints_visible, keypoints_visible_weights = \
+                keypoints_visible[..., 0], keypoints_visible[..., 1]
+            results['keypoints_visible_weights'] = keypoints_visible_weights
 
         # Encoded items from the encoder(s) will be updated into the results.
        # Please refer to the document of the specific codec for details about
@@ -1031,16 +1037,6 @@ def transform(self, results: Dict) -> Optional[dict]:
 
         results.update(encoded)
 
-        if results.get('keypoint_weights', None) is not None:
-            results['transformed_keypoints_visible'] = results[
-                'keypoint_weights']
-        elif results.get('keypoints', None) is not None:
-            results['transformed_keypoints_visible'] = results[
-                'keypoints_visible']
-        else:
-            raise ValueError('GenerateTarget requires \'keypoint_weights\' or'
-                             ' \'keypoints_visible\' in the results.')
-
         return results
 
     def __repr__(self) -> str:
diff --git a/mmpose/datasets/transforms/converting.py b/mmpose/datasets/transforms/converting.py
index 38dcea0994..c8a4a172cf 100644
--- a/mmpose/datasets/transforms/converting.py
+++ b/mmpose/datasets/transforms/converting.py
@@ -87,13 +87,18 @@ def __init__(self, num_keypoints: int,
         self.interpolation = interpolation
 
     def transform(self, results: dict) -> dict:
+        """Transforms the keypoint results to match the target keypoints."""
         num_instances = results['keypoints'].shape[0]
 
+        # Initialize output arrays
         keypoints = np.zeros((num_instances, self.num_keypoints, 2))
        keypoints_visible = np.zeros((num_instances, self.num_keypoints))
 
-        # When paired source_indexes are input,
-        # perform interpolation with self.source_index and self.source_index2
+        # Create a mask to weight visibility loss
+        keypoints_visible_weights = keypoints_visible.copy()
+        keypoints_visible_weights[:, self.target_index] = 1.0
+
+        # Interpolate keypoints if pairs of source indexes are provided
         if self.interpolation:
             keypoints[:, self.target_index] = 0.5 * (
                 results['keypoints'][:, self.source_index] +
@@ -102,6 +107,8 @@ def transform(self, results: dict) -> dict:
             keypoints_visible[:, self.target_index] = results[
                 'keypoints_visible'][:, self.source_index] * \
                 results['keypoints_visible'][:, self.source_index2]
+
+        # Otherwise just copy from the source index
         else:
             keypoints[:, self.target_index] = results['keypoints'][:, self.
@@ -109,8 +116,10 @@ def transform(self, results: dict) -> dict:
             keypoints_visible[:, self.target_index] = results[
                 'keypoints_visible'][:, self.source_index]
 
+        # Update the results dict
         results['keypoints'] = keypoints
-        results['keypoints_visible'] = keypoints_visible
+        results['keypoints_visible'] = np.stack(
+            [keypoints_visible, keypoints_visible_weights], axis=2)
 
         return results
 
     def __repr__(self) -> str:
diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py
index d047cff3c3..c2431c70bf 100644
--- a/mmpose/datasets/transforms/formatting.py
+++ b/mmpose/datasets/transforms/formatting.py
@@ -128,7 +128,7 @@ class PackPoseInputs(BaseTransform):
         'keypoint_y_labels': 'keypoint_y_labels',
         'keypoint_weights': 'keypoint_weights',
         'instance_coords': 'instance_coords',
-        'transformed_keypoints_visible': 'keypoints_visible',
+        'keypoints_visible_weights': 'keypoints_visible_weights'
     }
 
     # items in `field_mapping_table` will be packed into
@@ -195,10 +195,6 @@ def transform(self, results: dict) -> dict:
         if self.pack_transformed and 'transformed_keypoints' in results:
             gt_instances.set_field(results['transformed_keypoints'],
                                    'transformed_keypoints')
-        if self.pack_transformed and \
-                'transformed_keypoints_visible' in results:
-            gt_instances.set_field(results['transformed_keypoints_visible'],
-                                   'transformed_keypoints_visible')
 
         data_sample.gt_instances = gt_instances
 
diff --git a/mmpose/models/heads/hybrid_heads/vis_head.py b/mmpose/models/heads/hybrid_heads/vis_head.py
index e9ea271ac5..f95634541b 100644
--- a/mmpose/models/heads/hybrid_heads/vis_head.py
+++ b/mmpose/models/heads/hybrid_heads/vis_head.py
@@ -31,8 +31,7 @@ def __init__(self,
                  pose_cfg: ConfigType,
                  loss: ConfigType = dict(
                      type='BCELoss', use_target_weight=False,
-                     with_logits=True),
-                 use_sigmoid: bool = False,
+                     use_sigmoid=True),
                  init_cfg: OptConfigType = None):
 
         if init_cfg is None:
@@ -54,14 +53,14 @@ def __init__(self,
         self.pose_head = MODELS.build(pose_cfg)
         self.pose_cfg = pose_cfg
 
-        self.use_sigmoid = use_sigmoid
+        self.use_sigmoid = loss.get('use_sigmoid', False)
 
         modules = [
             nn.AdaptiveAvgPool2d(1),
             nn.Flatten(),
             nn.Linear(self.in_channels, self.out_channels)
         ]
-        if use_sigmoid:
+        if self.use_sigmoid:
             modules.append(nn.Sigmoid())
 
         self.vis_head = nn.Sequential(*modules)
 
@@ -113,7 +112,7 @@ def integrate(self, batch_vis: Tensor,
         assert len(pose_pred_instances) == len(batch_vis_np)
 
         for index, _ in enumerate(pose_pred_instances):
-            pose_pred_instances[index].keypoint_scores = batch_vis_np[index]
+            pose_pred_instances[index].keypoints_visible = batch_vis_np[index]
 
         return pose_pred_instances, pose_pred_fields
 
@@ -176,15 +175,20 @@ def predict(self,
 
         return self.integrate(batch_vis, batch_pose)
 
-    def vis_accuracy(self, vis_pred_outputs, vis_labels):
+    @torch.no_grad()
+    def vis_accuracy(self, vis_pred_outputs, vis_labels, vis_weights=None):
         """Calculate visibility prediction accuracy."""
-        probabilities = torch.sigmoid(torch.flatten(vis_pred_outputs))
+        if not self.use_sigmoid:
+            vis_pred_outputs = torch.sigmoid(vis_pred_outputs)
         threshold = 0.5
-        predictions = (probabilities >= threshold).int()
-        labels = torch.flatten(vis_labels)
-        correct = torch.sum(predictions == labels).item()
-        accuracy = correct / len(labels)
-        return torch.tensor(accuracy)
+        predictions = (vis_pred_outputs >= threshold).float()
+        correct = (predictions == vis_labels).float()
+        if vis_weights is not None:
+            accuracy = (correct * vis_weights).sum(dim=1) / (
+                vis_weights.sum(dim=1) + 1e-6)
+        else:
+            accuracy = correct.mean(dim=1)
+        return accuracy.mean()
 
     def loss(self,
              feats: Tuple[Tensor],
@@ -203,18 +207,26 @@
             dict: A dictionary of losses.
         """
         vis_pred_outputs = self.vis_forward(feats)
-        vis_labels = torch.cat([
-            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
-        ])
+        vis_labels = []
+        vis_weights = [] if self.loss_module.use_target_weight else None
+        for d in batch_data_samples:
+            vis_label = d.gt_instance_labels.keypoint_weights.float()
+            vis_labels.append(vis_label)
+            if vis_weights is not None:
+                vis_weights.append(
+                    getattr(d.gt_instance_labels, 'keypoints_visible_weights',
+                            vis_label.new_ones(vis_label.shape)))
+        vis_labels = torch.cat(vis_labels)
+        vis_weights = torch.cat(vis_weights) if vis_weights else None
 
         # calculate vis losses
         losses = dict()
-        loss_vis = self.loss_module(vis_pred_outputs, vis_labels)
+        loss_vis = self.loss_module(vis_pred_outputs, vis_labels, vis_weights)
 
         losses.update(loss_vis=loss_vis)
 
         # calculate vis accuracy
-        acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels)
+        acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels, vis_weights)
         losses.update(acc_vis=acc_vis)
 
         # calculate keypoints losses
diff --git a/mmpose/models/losses/classification_loss.py b/mmpose/models/losses/classification_loss.py
index 4605acabd3..5d2a2c7a58 100644
--- a/mmpose/models/losses/classification_loss.py
+++ b/mmpose/models/losses/classification_loss.py
@@ -14,15 +14,17 @@ class BCELoss(nn.Module):
         use_target_weight (bool): Option to use weighted loss.
             Different joint types may have different target weights.
         loss_weight (float): Weight of the loss. Default: 1.0.
-        with_logits (bool): Whether to use BCEWithLogitsLoss. Default: False.
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            before output. Defaults to False.
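+            When ``use_sigmoid=True``, the model is expected to output
+            probabilities (e.g. via a trailing sigmoid layer), so plain
+            binary cross entropy is used; otherwise the loss applies
+            ``F.binary_cross_entropy_with_logits`` to raw logits.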
""" def __init__(self, use_target_weight=False, loss_weight=1., - with_logits=False): + use_sigmoid=False): super().__init__() - self.criterion = F.binary_cross_entropy if not with_logits\ + self.use_sigmoid = use_sigmoid + self.criterion = F.binary_cross_entropy if use_sigmoid \ else F.binary_cross_entropy_with_logits self.use_target_weight = use_target_weight self.loss_weight = loss_weight diff --git a/mmpose/models/pose_estimators/bottomup.py b/mmpose/models/pose_estimators/bottomup.py index 5400f2478e..e7d2aaef88 100644 --- a/mmpose/models/pose_estimators/bottomup.py +++ b/mmpose/models/pose_estimators/bottomup.py @@ -169,6 +169,9 @@ def add_pred_to_datasample(self, batch_pred_instances: InstanceList, pred_instances.keypoints = pred_instances.keypoints / input_size \ * input_scale + input_center - 0.5 * input_scale + if 'keypoints_visible' not in pred_instances: + pred_instances.keypoints_visible = \ + pred_instances.keypoint_scores data_sample.pred_instances = pred_instances diff --git a/mmpose/models/pose_estimators/topdown.py b/mmpose/models/pose_estimators/topdown.py index 89b332893f..0704627bd5 100644 --- a/mmpose/models/pose_estimators/topdown.py +++ b/mmpose/models/pose_estimators/topdown.py @@ -153,6 +153,9 @@ def add_pred_to_datasample(self, batch_pred_instances: InstanceList, pred_instances.keypoints = pred_instances.keypoints / input_size \ * bbox_scales + bbox_centers - 0.5 * bbox_scales + if 'keypoints_visible' not in pred_instances: + pred_instances.keypoints_visible = \ + pred_instances.keypoint_scores if output_keypoint_indices is not None: # select output keypoints with given indices diff --git a/mmpose/structures/keypoint/transforms.py b/mmpose/structures/keypoint/transforms.py index b50da4f8fe..bd7274dadf 100644 --- a/mmpose/structures/keypoint/transforms.py +++ b/mmpose/structures/keypoint/transforms.py @@ -20,8 +20,8 @@ def flip_keypoints(keypoints: np.ndarray, Args: keypoints (np.ndarray): Keypoints in shape (..., K, D) keypoints_visible (np.ndarray, optional): The visibility of keypoints - in shape (..., K, 1). Set ``None`` if the keypoint visibility is - unavailable + in shape (..., K, 1) or (..., K, 2). Set ``None`` if the keypoint + visibility is unavailable image_size (tuple): The image shape in [w, h] flip_indices (List[int]): The indices of each keypoint's symmetric keypoint @@ -33,11 +33,12 @@ def flip_keypoints(keypoints: np.ndarray, - keypoints_flipped (np.ndarray): Flipped keypoints in shape (..., K, D) - keypoints_visible_flipped (np.ndarray, optional): Flipped keypoints' - visibility in shape (..., K, 1). Return ``None`` if the input - ``keypoints_visible`` is ``None`` + visibility in shape (..., K, 1) or (..., K, 2). 
+            the input ``keypoints_visible`` is ``None``
     """
-    assert keypoints.shape[:-1] == keypoints_visible.shape, (
+    ndim = keypoints.ndim
+    assert keypoints.shape[:-1] == keypoints_visible.shape[:ndim - 1], (
         f'Mismatched shapes of keypoints {keypoints.shape} and '
         f'keypoints_visible {keypoints_visible.shape}')
 
@@ -48,9 +49,10 @@
     # swap the symmetric keypoint pairs
     if direction == 'horizontal' or direction == 'vertical':
-        keypoints = keypoints[..., flip_indices, :]
+        keypoints = keypoints.take(flip_indices, axis=ndim - 2)
         if keypoints_visible is not None:
-            keypoints_visible = keypoints_visible[..., flip_indices]
+            keypoints_visible = keypoints_visible.take(
+                flip_indices, axis=ndim - 2)
 
     # flip the keypoints
     w, h = image_size
diff --git a/mmpose/visualization/local_visualizer.py b/mmpose/visualization/local_visualizer.py
index 080e628e33..1eb994f03a 100644
--- a/mmpose/visualization/local_visualizer.py
+++ b/mmpose/visualization/local_visualizer.py
@@ -253,11 +253,6 @@ def _draw_instances_kpts(self,
             keypoints = instances.get('transformed_keypoints',
                                       instances.keypoints)
 
-            if 'keypoint_scores' in instances:
-                scores = instances.keypoint_scores
-            else:
-                scores = np.ones(keypoints.shape[:-1])
-
             if 'keypoints_visible' in instances:
                 keypoints_visible = instances.keypoints_visible
             else:
@@ -265,15 +260,13 @@
             if skeleton_style == 'openpose':
                 keypoints_info = np.concatenate(
-                    (keypoints, scores[..., None], keypoints_visible[...,
-                                                                     None]),
-                    axis=-1)
+                    (keypoints, keypoints_visible[..., None]), axis=-1)
                 # compute neck joint
                 neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
                 # neck score when visualizing pred
-                neck[:, 2:4] = np.logical_and(
-                    keypoints_info[:, 5, 2:4] > kpt_thr,
-                    keypoints_info[:, 6, 2:4] > kpt_thr).astype(int)
+                neck[:, 2:3] = np.logical_and(
+                    keypoints_info[:, 5, 2:3] > kpt_thr,
+                    keypoints_info[:, 6, 2:3] > kpt_thr).astype(int)
                 new_keypoints_info = np.insert(
                     keypoints_info, 17, neck, axis=1)
 
@@ -287,11 +280,10 @@
                     new_keypoints_info[:, mmpose_idx]
                 keypoints_info = new_keypoints_info
 
-                keypoints, scores, keypoints_visible = keypoints_info[
-                    ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]
+                keypoints, keypoints_visible = keypoints_info[
+                    ..., :2], keypoints_info[..., 2]
 
-            for kpts, score, visible in zip(keypoints, scores,
-                                            keypoints_visible):
+            for kpts, visible in zip(keypoints, keypoints_visible):
                 kpts = np.array(kpts, copy=False)
 
                 if self.kpt_color is None or isinstance(self.kpt_color, str):
@@ -320,17 +312,16 @@
                 for sk_id, sk in enumerate(self.skeleton):
                     pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                     pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
-                    if not (visible[sk[0]] and visible[sk[1]]):
-                        continue
 
                     if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
                             or pos1[1] >= img_h or pos2[0] <= 0
                            or pos2[0] >= img_w or pos2[1] <= 0
-                            or pos2[1] >= img_h or score[sk[0]] < kpt_thr
-                            or score[sk[1]] < kpt_thr
+                            or pos2[1] >= img_h or visible[sk[0]] < kpt_thr
+                            or visible[sk[1]] < kpt_thr
                             or link_color[sk_id] is None):
                         # skip the link that should not be drawn
                         continue
+
                     X = np.array((pos1[0], pos2[0]))
                     Y = np.array((pos1[1], pos2[1]))
                     color = link_color[sk_id]
@@ -339,7 +330,9 @@
                     transparency = self.alpha
                     if self.show_keypoint_weight:
                         transparency *= max(
-                            0, min(1, 0.5 * (score[sk[0]] + score[sk[1]])))
+                            0,
+                            min(1,
+                                0.5 * (visible[sk[0]] + visible[sk[1]])))
 
                     if skeleton_style == 'openpose':
                        mX = np.mean(X)
@@ -365,8 +358,7 @@
 
                 # draw each point on image
                 for kid, kpt in enumerate(kpts):
-                    if score[kid] < kpt_thr or not visible[
-                            kid] or kpt_color[kid] is None:
+                    if visible[kid] < kpt_thr or kpt_color[kid] is None:
                         # skip the point that should not be drawn
                         continue
 
                     color = tuple(int(c) for c in color)
                     transparency = self.alpha
                     if self.show_keypoint_weight:
-                        transparency *= max(0, min(1, score[kid]))
+                        transparency *= max(0, min(1, visible[kid]))
                     self.draw_circles(
                         kpt,
                         radius=np.array([self.radius]),
diff --git a/tests/test_datasets/test_transforms/test_converting.py b/tests/test_datasets/test_transforms/test_converting.py
index 09f06e1e65..08561b1d0f 100644
--- a/tests/test_datasets/test_transforms/test_converting.py
+++ b/tests/test_datasets/test_transforms/test_converting.py
@@ -32,8 +32,10 @@ def test_transform(self):
             self.assertTrue((results['keypoints'][:, target_index] ==
                              self.data_info['keypoints'][:, source_index]).all())
+            self.assertEqual(results['keypoints_visible'].ndim, 3)
+            self.assertEqual(results['keypoints_visible'].shape[2], 2)
             self.assertTrue(
-                (results['keypoints_visible'][:, target_index] ==
+                (results['keypoints_visible'][:, target_index, 0] ==
                  self.data_info['keypoints_visible'][:, source_index]).all())
 
         # 2-to-1 mapping
@@ -58,8 +60,10 @@ def test_transform(self):
            (results['keypoints'][:, target_index] == 0.5 *
             (self.data_info['keypoints'][:, source_index] +
              self.data_info['keypoints'][:, source_index2])).all())
+        self.assertEqual(results['keypoints_visible'].ndim, 3)
+        self.assertEqual(results['keypoints_visible'].shape[2], 2)
        self.assertTrue(
-            (results['keypoints_visible'][:, target_index] ==
+            (results['keypoints_visible'][:, target_index, 0] ==
             self.data_info['keypoints_visible'][:, source_index] *
             self.data_info['keypoints_visible'][:, source_index2]).all())
 
@@ -67,7 +71,9 @@ def test_transform(self):
        self.assertTrue(
            (results['keypoints'][:, target_index] ==
             self.data_info['keypoints'][:, source_index]).all())
+        self.assertEqual(results['keypoints_visible'].ndim, 3)
+        self.assertEqual(results['keypoints_visible'].shape[2], 2)
        self.assertTrue(
-            (results['keypoints_visible'][:, target_index] ==
+            (results['keypoints_visible'][:, target_index, 0] ==
             self.data_info['keypoints_visible'][:, source_index]).all())
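
With these changes, `pred_instances` carries a `keypoints_visible` field for both top-down and bottom-up estimators: heads wrapped in `VisPredictHead` fill it with learned visibility probabilities, while all other heads fall back to `keypoint_scores`. A minimal inference sketch (assuming mmpose dev-1.x with this patch applied; `person.jpg` is a placeholder image path):

```python
# Minimal inference sketch using the new config and checkpoint from this PR.
from mmpose.apis import inference_topdown, init_model

config = ('configs/body_2d_keypoint/topdown_heatmap/coco/'
          'td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py')
checkpoint = ('https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/'
              'topdown_heatmap/coco/'
              'td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge'
              '-21815b2c_20230726.pth')

model = init_model(config, checkpoint, device='cpu')

for sample in inference_topdown(model, 'person.jpg'):
    pred = sample.pred_instances
    # With VisPredictHead, `keypoints_visible` holds the learned visibility
    # probabilities; heads without a visibility branch fall back to
    # `keypoint_scores` (see the topdown.py/bottomup.py changes above).
    print(pred.keypoints.shape, pred.keypoints_visible.shape)
```

The two-channel visibility convention introduced in `KeypointConverter` can also be checked in isolation; channel 0 carries the visibility values and channel 1 the per-keypoint loss-weight mask (a standalone sketch with hypothetical shapes):

```python
import numpy as np

# Hypothetical shapes: 1 instance, 17 keypoints.
keypoints_visible = np.random.rand(1, 17)      # visibility values
keypoints_visible_weights = np.ones((1, 17))   # 1 = supervise this keypoint
packed = np.stack([keypoints_visible, keypoints_visible_weights], axis=2)
assert packed.shape == (1, 17, 2)  # what downstream transforms now expect
```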