diff --git a/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md b/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md
index 4ce6da38c6..dbe14267ed 100644
--- a/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md
+++ b/configs/body_2d_keypoint/topdown_heatmap/coco/resnet_coco.md
@@ -60,3 +60,9 @@ Results on COCO val2017 with detector having human AP of 56.4 on COCO val2017 da
| [pose_resnet_101](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb32-210e_coco-384x288.py) | 384x288 | 0.749 | 0.906 | 0.817 | 0.799 | 0.941 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb64-210e_coco-256x192-065d3625_20220926.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res101_8xb64-210e_coco-256x192_20220926.log) |
| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192.py) | 256x192 | 0.736 | 0.904 | 0.818 | 0.791 | 0.942 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192-0345f330_20220928.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-256x192_20220928.log) |
| [pose_resnet_152](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288.py) | 384x288 | 0.750 | 0.908 | 0.821 | 0.800 | 0.942 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288-7fbb906f_20220927.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res152_8xb32-210e_coco-384x288_20220927.log) |
+
+The following model is equipped with a visibility prediction head and was trained on the COCO and AIC datasets.
+
+| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log |
+| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
+| [pose_resnet_50](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py) | 256x192 | 0.729 | 0.900 | 0.807 | 0.783 | 0.938 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge-21815b2c_20230726.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/topdown_heatmap/coco/td-hm_res50_8xb64-210e_coco-256x192_20220923.log) |
diff --git a/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py b/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py
new file mode 100644
index 0000000000..f5def39ed9
--- /dev/null
+++ b/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm-vis_res50_8xb64-210e_coco-aic-256x192-merge.py
@@ -0,0 +1,167 @@
+_base_ = ['../../../_base_/default_runtime.py']
+
+# runtime
+train_cfg = dict(max_epochs=210, val_interval=10)
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+ type='Adam',
+ lr=5e-4,
+))
+
+# learning policy
+param_scheduler = [
+ dict(
+ type='LinearLR', begin=0, end=500, start_factor=0.001,
+ by_epoch=False), # warm-up
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=210,
+ milestones=[170, 200],
+ gamma=0.1,
+ by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=512)
+
+# hooks
+default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+ type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)
+
+# model settings
+model = dict(
+ type='TopdownPoseEstimator',
+ data_preprocessor=dict(
+ type='PoseDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True),
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
+ ),
+ head=dict(
+ type='VisPredictHead',
+ loss=dict(
+ type='BCELoss',
+ use_target_weight=True,
+ use_sigmoid=True,
+ loss_weight=1e-3,
+ ),
+ pose_cfg=dict(
+ type='HeatmapHead',
+ in_channels=2048,
+ out_channels=17,
+ loss=dict(type='KeypointMSELoss', use_target_weight=True),
+ decoder=codec)),
+ test_cfg=dict(
+ flip_test=True,
+ flip_mode='heatmap',
+ shift_heatmap=True,
+ ))
+
+# base dataset settings
+dataset_type = 'CocoDataset'
+data_mode = 'topdown'
+data_root = 'data/coco/'
+
+# pipelines
+train_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='RandomFlip', direction='horizontal'),
+ dict(type='RandomHalfBody'),
+ dict(type='RandomBBoxTransform'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='GenerateTarget', encoder=codec),
+ dict(type='PackPoseInputs')
+]
+val_pipeline = [
+ dict(type='LoadImage'),
+ dict(type='GetBBoxCenterScale'),
+ dict(type='TopdownAffine', input_size=codec['input_size']),
+ dict(type='PackPoseInputs')
+]
+
+# train datasets
+dataset_coco = dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=[],
+)
+
+dataset_aic = dict(
+ type='AicDataset',
+ data_root='data/aic/',
+ data_mode=data_mode,
+ ann_file='annotations/aic_train.json',
+ data_prefix=dict(img='ai_challenger_keypoint_train_20170902/'
+ 'keypoint_train_images_20170902/'),
+ pipeline=[
+ dict(
+ type='KeypointConverter',
+ num_keypoints=17,
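+            # Each pair maps an AIC keypoint index (source, left) to the
+            # corresponding COCO keypoint index (target, right).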
+ mapping=[
+ (0, 6),
+ (1, 8),
+ (2, 10),
+ (3, 5),
+ (4, 7),
+ (5, 9),
+ (6, 12),
+ (7, 14),
+ (8, 16),
+ (9, 11),
+ (10, 13),
+ (11, 15),
+ ])
+ ],
+)
+
+# data loaders
+train_dataloader = dict(
+ batch_size=64,
+ num_workers=2,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type='CombinedDataset',
+ metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
+ datasets=[dataset_coco, dataset_aic],
+ pipeline=train_pipeline,
+ test_mode=False,
+ ))
+val_dataloader = dict(
+ batch_size=32,
+ num_workers=2,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+ dataset=dict(
+ type=dataset_type,
+ data_root=data_root,
+ data_mode=data_mode,
+ ann_file='annotations/person_keypoints_val2017.json',
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=val_pipeline,
+ ))
+test_dataloader = val_dataloader
+
+# evaluators
+val_evaluator = dict(
+ type='CocoMetric',
+ ann_file=data_root + 'annotations/person_keypoints_val2017.json')
+test_evaluator = val_evaluator
diff --git a/docs/en/advanced_guides/implement_new_models.md b/docs/en/advanced_guides/implement_new_models.md
index da46a99e39..ff54e2c5ff 100644
--- a/docs/en/advanced_guides/implement_new_models.md
+++ b/docs/en/advanced_guides/implement_new_models.md
@@ -79,3 +79,27 @@ class YourNewHead(BaseHead):
```
Finally, please remember to import your new prediction head in [`__init__.py`](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/__init__.py).
+
+### Head with Keypoint Visibility Prediction
+
+Many models estimate keypoint visibility from the confidence of their coordinate predictions; however, this approach is suboptimal. The [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) wrapper enables a head to predict keypoint visibility directly, supervised by the ground-truth visibility labels in the training data, which makes its predictions more reliable. To add visibility prediction, wrap your head module with `VisPredictHead` in the config file:
+
+```python
+model=dict(
+ ...
+ head=dict(
+ type='VisPredictHead',
+ loss=dict(
+ type='BCELoss',
+ use_target_weight=True,
+ use_sigmoid=True,
+ loss_weight=1e-3),
+ pose_cfg=dict(
+ type='HeatmapHead',
+ in_channels=2048,
+ out_channels=17,
+ loss=dict(type='KeypointMSELoss', use_target_weight=True),
+ decoder=codec)),
+ ...
+)
+```
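+
+After training, the wrapped head writes its visibility predictions to `pred_instances.keypoints_visible`. Below is a minimal sketch of reading them through the high-level inference API; the config and checkpoint paths are placeholders for a model trained with `VisPredictHead`.
+
+```python
+from mmpose.apis import inference_topdown, init_model
+
+# Placeholder paths: substitute your own config and checkpoint
+# trained with VisPredictHead.
+model = init_model('my_vis_config.py', 'my_vis_checkpoint.pth')
+results = inference_topdown(model, 'demo.jpg')
+
+# Per-keypoint visibility predicted by the visibility branch.
+print(results[0].pred_instances.keypoints_visible)
+```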
diff --git a/docs/zh_cn/advanced_guides/implement_new_models.md b/docs/zh_cn/advanced_guides/implement_new_models.md
index d3ed96bd37..22e866b52b 100644
--- a/docs/zh_cn/advanced_guides/implement_new_models.md
+++ b/docs/zh_cn/advanced_guides/implement_new_models.md
@@ -78,3 +78,27 @@ class YourNewHead(BaseHead):
```
Finally, remember to import your new prediction head in [heads/\_\_init\_\_.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/__init__.py).
+
+### Head with Keypoint Visibility Prediction
+
+Many models judge keypoint visibility by the confidence of their coordinate predictions; however, this solution is suboptimal. We provide a head-module wrapper called [`VisPredictHead`](https://github.com/open-mmlab/mmpose/blob/dev-1.x/mmpose/models/heads/hybrid_heads/vis_head.py) that enables a head module to predict keypoint visibility directly. Since the wrapper is trained with the ground-truth keypoint visibility in the training data, its predictions are more reliable. You can add this wrapper to your own head module by modifying the config file. Here is an example:
+
+```python
+model=dict(
+ ...
+ head=dict(
+ type='VisPredictHead',
+ loss=dict(
+ type='BCELoss',
+ use_target_weight=True,
+ use_sigmoid=True,
+ loss_weight=1e-3),
+ pose_cfg=dict(
+ type='HeatmapHead',
+ in_channels=2048,
+ out_channels=17,
+ loss=dict(type='KeypointMSELoss', use_target_weight=True),
+ decoder=codec)),
+ ...
+)
+```
diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py
index 87068246f8..8f7aa05425 100644
--- a/mmpose/datasets/transforms/common_transforms.py
+++ b/mmpose/datasets/transforms/common_transforms.py
@@ -340,7 +340,7 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,
Args:
keypoints_visible (np.ndarray, optional): The visibility of
- keypoints in shape (N, K, 1).
+ keypoints in shape (N, K, 1) or (N, K, 2).
upper_body_ids (list): The list of upper body keypoint indices
lower_body_ids (list): The list of lower body keypoint indices
@@ -349,6 +349,9 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,
of each instance. ``None`` means not applying half-body transform.
"""
+ if keypoints_visible.ndim == 3:
+ keypoints_visible = keypoints_visible[..., 0]
+
half_body_ids = []
for visible in keypoints_visible:
@@ -390,7 +393,6 @@ def transform(self, results: Dict) -> Optional[dict]:
Returns:
dict: The result dict.
"""
-
half_body_ids = self._random_select_half_body(
keypoints_visible=results['keypoints_visible'],
upper_body_ids=results['upper_body_ids'],
@@ -952,6 +954,10 @@ def transform(self, results: Dict) -> Optional[dict]:
' \'keypoints\' in the results.')
keypoints_visible = results['keypoints_visible']
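+        # A (N, K, 2) visibility array carries both the visibility values and
+        # the per-keypoint loss weights; split them before encoding.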
+ if keypoints_visible.ndim == 3 and keypoints_visible.shape[2] == 2:
+ keypoints_visible, keypoints_visible_weights = \
+ keypoints_visible[..., 0], keypoints_visible[..., 1]
+ results['keypoints_visible_weights'] = keypoints_visible_weights
# Encoded items from the encoder(s) will be updated into the results.
# Please refer to the document of the specific codec for details about
@@ -1031,16 +1037,6 @@ def transform(self, results: Dict) -> Optional[dict]:
results.update(encoded)
- if results.get('keypoint_weights', None) is not None:
- results['transformed_keypoints_visible'] = results[
- 'keypoint_weights']
- elif results.get('keypoints', None) is not None:
- results['transformed_keypoints_visible'] = results[
- 'keypoints_visible']
- else:
- raise ValueError('GenerateTarget requires \'keypoint_weights\' or'
- ' \'keypoints_visible\' in the results.')
-
return results
def __repr__(self) -> str:
diff --git a/mmpose/datasets/transforms/converting.py b/mmpose/datasets/transforms/converting.py
index 38dcea0994..c8a4a172cf 100644
--- a/mmpose/datasets/transforms/converting.py
+++ b/mmpose/datasets/transforms/converting.py
@@ -87,13 +87,18 @@ def __init__(self, num_keypoints: int,
self.interpolation = interpolation
def transform(self, results: dict) -> dict:
+ """Transforms the keypoint results to match the target keypoints."""
num_instances = results['keypoints'].shape[0]
+ # Initialize output arrays
keypoints = np.zeros((num_instances, self.num_keypoints, 2))
keypoints_visible = np.zeros((num_instances, self.num_keypoints))
- # When paired source_indexes are input,
- # perform interpolation with self.source_index and self.source_index2
+    # Create a mask to weight the visibility loss
+ keypoints_visible_weights = keypoints_visible.copy()
+ keypoints_visible_weights[:, self.target_index] = 1.0
+
+    # Interpolate keypoints if pairs of source indices are provided
if self.interpolation:
keypoints[:, self.target_index] = 0.5 * (
results['keypoints'][:, self.source_index] +
@@ -102,6 +107,8 @@ def transform(self, results: dict) -> dict:
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index] * \
results['keypoints_visible'][:, self.source_index2]
+
+ # Otherwise just copy from the source index
else:
keypoints[:,
self.target_index] = results['keypoints'][:, self.
@@ -109,8 +116,10 @@ def transform(self, results: dict) -> dict:
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index]
+ # Update the results dict
results['keypoints'] = keypoints
- results['keypoints_visible'] = keypoints_visible
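+    # Stack the visibility values and their loss weights into (N, K, 2)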
+ results['keypoints_visible'] = np.stack(
+ [keypoints_visible, keypoints_visible_weights], axis=2)
return results
def __repr__(self) -> str:
diff --git a/mmpose/datasets/transforms/formatting.py b/mmpose/datasets/transforms/formatting.py
index d047cff3c3..c2431c70bf 100644
--- a/mmpose/datasets/transforms/formatting.py
+++ b/mmpose/datasets/transforms/formatting.py
@@ -128,7 +128,7 @@ class PackPoseInputs(BaseTransform):
'keypoint_y_labels': 'keypoint_y_labels',
'keypoint_weights': 'keypoint_weights',
'instance_coords': 'instance_coords',
- 'transformed_keypoints_visible': 'keypoints_visible',
+ 'keypoints_visible_weights': 'keypoints_visible_weights'
}
# items in `field_mapping_table` will be packed into
@@ -195,10 +195,6 @@ def transform(self, results: dict) -> dict:
if self.pack_transformed and 'transformed_keypoints' in results:
gt_instances.set_field(results['transformed_keypoints'],
'transformed_keypoints')
- if self.pack_transformed and \
- 'transformed_keypoints_visible' in results:
- gt_instances.set_field(results['transformed_keypoints_visible'],
- 'transformed_keypoints_visible')
data_sample.gt_instances = gt_instances
diff --git a/mmpose/models/heads/hybrid_heads/vis_head.py b/mmpose/models/heads/hybrid_heads/vis_head.py
index e9ea271ac5..f95634541b 100644
--- a/mmpose/models/heads/hybrid_heads/vis_head.py
+++ b/mmpose/models/heads/hybrid_heads/vis_head.py
@@ -31,8 +31,7 @@ def __init__(self,
pose_cfg: ConfigType,
loss: ConfigType = dict(
type='BCELoss', use_target_weight=False,
- with_logits=True),
- use_sigmoid: bool = False,
+ use_sigmoid=True),
init_cfg: OptConfigType = None):
if init_cfg is None:
@@ -54,14 +53,14 @@ def __init__(self,
self.pose_head = MODELS.build(pose_cfg)
self.pose_cfg = pose_cfg
- self.use_sigmoid = use_sigmoid
+ self.use_sigmoid = loss.get('use_sigmoid', False)
modules = [
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(self.in_channels, self.out_channels)
]
- if use_sigmoid:
+ if self.use_sigmoid:
modules.append(nn.Sigmoid())
self.vis_head = nn.Sequential(*modules)
@@ -113,7 +112,7 @@ def integrate(self, batch_vis: Tensor,
assert len(pose_pred_instances) == len(batch_vis_np)
for index, _ in enumerate(pose_pred_instances):
- pose_pred_instances[index].keypoint_scores = batch_vis_np[index]
+ pose_pred_instances[index].keypoints_visible = batch_vis_np[index]
return pose_pred_instances, pose_pred_fields
@@ -176,15 +175,20 @@ def predict(self,
return self.integrate(batch_vis, batch_pose)
- def vis_accuracy(self, vis_pred_outputs, vis_labels):
+ @torch.no_grad()
+ def vis_accuracy(self, vis_pred_outputs, vis_labels, vis_weights=None):
"""Calculate visibility prediction accuracy."""
- probabilities = torch.sigmoid(torch.flatten(vis_pred_outputs))
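+        # The outputs are raw logits unless the head applies sigmoid itself.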
+ if not self.use_sigmoid:
+ vis_pred_outputs = torch.sigmoid(vis_pred_outputs)
threshold = 0.5
- predictions = (probabilities >= threshold).int()
- labels = torch.flatten(vis_labels)
- correct = torch.sum(predictions == labels).item()
- accuracy = correct / len(labels)
- return torch.tensor(accuracy)
+ predictions = (vis_pred_outputs >= threshold).float()
+ correct = (predictions == vis_labels).float()
+ if vis_weights is not None:
+ accuracy = (correct * vis_weights).sum(dim=1) / (
+                vis_weights.sum(dim=1) + 1e-6)
+ else:
+ accuracy = correct.mean(dim=1)
+ return accuracy.mean()
def loss(self,
feats: Tuple[Tensor],
@@ -203,18 +207,26 @@ def loss(self,
dict: A dictionary of losses.
"""
vis_pred_outputs = self.vis_forward(feats)
- vis_labels = torch.cat([
- d.gt_instance_labels.keypoint_weights for d in batch_data_samples
- ])
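+        # Gather visibility labels, plus per-keypoint weights when the loss
+        # module uses target weights.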
+ vis_labels = []
+ vis_weights = [] if self.loss_module.use_target_weight else None
+ for d in batch_data_samples:
+ vis_label = d.gt_instance_labels.keypoint_weights.float()
+ vis_labels.append(vis_label)
+ if vis_weights is not None:
+ vis_weights.append(
+ getattr(d.gt_instance_labels, 'keypoints_visible_weights',
+ vis_label.new_ones(vis_label.shape)))
+ vis_labels = torch.cat(vis_labels)
+ vis_weights = torch.cat(vis_weights) if vis_weights else None
# calculate vis losses
losses = dict()
- loss_vis = self.loss_module(vis_pred_outputs, vis_labels)
+ loss_vis = self.loss_module(vis_pred_outputs, vis_labels, vis_weights)
losses.update(loss_vis=loss_vis)
# calculate vis accuracy
- acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels)
+ acc_vis = self.vis_accuracy(vis_pred_outputs, vis_labels, vis_weights)
losses.update(acc_vis=acc_vis)
# calculate keypoints losses
diff --git a/mmpose/models/losses/classification_loss.py b/mmpose/models/losses/classification_loss.py
index 4605acabd3..5d2a2c7a58 100644
--- a/mmpose/models/losses/classification_loss.py
+++ b/mmpose/models/losses/classification_loss.py
@@ -14,15 +14,17 @@ class BCELoss(nn.Module):
use_target_weight (bool): Option to use weighted loss.
Different joint types may have different target weights.
loss_weight (float): Weight of the loss. Default: 1.0.
- with_logits (bool): Whether to use BCEWithLogitsLoss. Default: False.
+ use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+ before output. Defaults to False.
"""
def __init__(self,
use_target_weight=False,
loss_weight=1.,
- with_logits=False):
+ use_sigmoid=False):
super().__init__()
- self.criterion = F.binary_cross_entropy if not with_logits\
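+        # If the prediction already passed through a sigmoid, use plain BCE
+        # on probabilities; otherwise use the numerically stabler
+        # binary_cross_entropy_with_logits on raw scores.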
+ self.use_sigmoid = use_sigmoid
+ self.criterion = F.binary_cross_entropy if use_sigmoid \
else F.binary_cross_entropy_with_logits
self.use_target_weight = use_target_weight
self.loss_weight = loss_weight
diff --git a/mmpose/models/pose_estimators/bottomup.py b/mmpose/models/pose_estimators/bottomup.py
index 5400f2478e..e7d2aaef88 100644
--- a/mmpose/models/pose_estimators/bottomup.py
+++ b/mmpose/models/pose_estimators/bottomup.py
@@ -169,6 +169,9 @@ def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
pred_instances.keypoints = pred_instances.keypoints / input_size \
* input_scale + input_center - 0.5 * input_scale
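+            # Fall back to keypoint scores when the model does not
+            # predict visibility explicitly.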
+ if 'keypoints_visible' not in pred_instances:
+ pred_instances.keypoints_visible = \
+ pred_instances.keypoint_scores
data_sample.pred_instances = pred_instances
diff --git a/mmpose/models/pose_estimators/topdown.py b/mmpose/models/pose_estimators/topdown.py
index 89b332893f..0704627bd5 100644
--- a/mmpose/models/pose_estimators/topdown.py
+++ b/mmpose/models/pose_estimators/topdown.py
@@ -153,6 +153,9 @@ def add_pred_to_datasample(self, batch_pred_instances: InstanceList,
pred_instances.keypoints = pred_instances.keypoints / input_size \
* bbox_scales + bbox_centers - 0.5 * bbox_scales
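+            # Fall back to keypoint scores when the model does not
+            # predict visibility explicitly.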
+ if 'keypoints_visible' not in pred_instances:
+ pred_instances.keypoints_visible = \
+ pred_instances.keypoint_scores
if output_keypoint_indices is not None:
# select output keypoints with given indices
diff --git a/mmpose/structures/keypoint/transforms.py b/mmpose/structures/keypoint/transforms.py
index b50da4f8fe..bd7274dadf 100644
--- a/mmpose/structures/keypoint/transforms.py
+++ b/mmpose/structures/keypoint/transforms.py
@@ -20,8 +20,8 @@ def flip_keypoints(keypoints: np.ndarray,
Args:
keypoints (np.ndarray): Keypoints in shape (..., K, D)
keypoints_visible (np.ndarray, optional): The visibility of keypoints
- in shape (..., K, 1). Set ``None`` if the keypoint visibility is
- unavailable
+ in shape (..., K, 1) or (..., K, 2). Set ``None`` if the keypoint
+ visibility is unavailable
image_size (tuple): The image shape in [w, h]
flip_indices (List[int]): The indices of each keypoint's symmetric
keypoint
@@ -33,11 +33,12 @@ def flip_keypoints(keypoints: np.ndarray,
- keypoints_flipped (np.ndarray): Flipped keypoints in shape
(..., K, D)
- keypoints_visible_flipped (np.ndarray, optional): Flipped keypoints'
- visibility in shape (..., K, 1). Return ``None`` if the input
- ``keypoints_visible`` is ``None``
+ visibility in shape (..., K, 1) or (..., K, 2). Return ``None`` if
+ the input ``keypoints_visible`` is ``None``
"""
- assert keypoints.shape[:-1] == keypoints_visible.shape, (
+ ndim = keypoints.ndim
+ assert keypoints.shape[:-1] == keypoints_visible.shape[:ndim - 1], (
f'Mismatched shapes of keypoints {keypoints.shape} and '
f'keypoints_visible {keypoints_visible.shape}')
@@ -48,9 +49,10 @@ def flip_keypoints(keypoints: np.ndarray,
# swap the symmetric keypoint pairs
if direction == 'horizontal' or direction == 'vertical':
- keypoints = keypoints[..., flip_indices, :]
+ keypoints = keypoints.take(flip_indices, axis=ndim - 2)
if keypoints_visible is not None:
- keypoints_visible = keypoints_visible[..., flip_indices]
+ keypoints_visible = keypoints_visible.take(
+ flip_indices, axis=ndim - 2)
# flip the keypoints
w, h = image_size
diff --git a/mmpose/visualization/local_visualizer.py b/mmpose/visualization/local_visualizer.py
index 080e628e33..1eb994f03a 100644
--- a/mmpose/visualization/local_visualizer.py
+++ b/mmpose/visualization/local_visualizer.py
@@ -253,11 +253,6 @@ def _draw_instances_kpts(self,
keypoints = instances.get('transformed_keypoints',
instances.keypoints)
- if 'keypoint_scores' in instances:
- scores = instances.keypoint_scores
- else:
- scores = np.ones(keypoints.shape[:-1])
-
if 'keypoints_visible' in instances:
keypoints_visible = instances.keypoints_visible
else:
@@ -265,15 +260,13 @@ def _draw_instances_kpts(self,
if skeleton_style == 'openpose':
keypoints_info = np.concatenate(
- (keypoints, scores[..., None], keypoints_visible[...,
- None]),
- axis=-1)
+ (keypoints, keypoints_visible[..., None]), axis=-1)
# compute neck joint
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
# neck score when visualizing pred
- neck[:, 2:4] = np.logical_and(
- keypoints_info[:, 5, 2:4] > kpt_thr,
- keypoints_info[:, 6, 2:4] > kpt_thr).astype(int)
+ neck[:, 2:3] = np.logical_and(
+ keypoints_info[:, 5, 2:3] > kpt_thr,
+ keypoints_info[:, 6, 2:3] > kpt_thr).astype(int)
new_keypoints_info = np.insert(
keypoints_info, 17, neck, axis=1)
@@ -287,11 +280,10 @@ def _draw_instances_kpts(self,
new_keypoints_info[:, mmpose_idx]
keypoints_info = new_keypoints_info
- keypoints, scores, keypoints_visible = keypoints_info[
- ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]
+ keypoints, keypoints_visible = keypoints_info[
+ ..., :2], keypoints_info[..., 2]
- for kpts, score, visible in zip(keypoints, scores,
- keypoints_visible):
+ for kpts, visible in zip(keypoints, keypoints_visible):
kpts = np.array(kpts, copy=False)
if self.kpt_color is None or isinstance(self.kpt_color, str):
@@ -320,17 +312,16 @@ def _draw_instances_kpts(self,
for sk_id, sk in enumerate(self.skeleton):
pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
- if not (visible[sk[0]] and visible[sk[1]]):
- continue
if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0
or pos1[1] >= img_h or pos2[0] <= 0
or pos2[0] >= img_w or pos2[1] <= 0
- or pos2[1] >= img_h or score[sk[0]] < kpt_thr
- or score[sk[1]] < kpt_thr
+ or pos2[1] >= img_h or visible[sk[0]] < kpt_thr
+ or visible[sk[1]] < kpt_thr
or link_color[sk_id] is None):
# skip the link that should not be drawn
continue
+
X = np.array((pos1[0], pos2[0]))
Y = np.array((pos1[1], pos2[1]))
color = link_color[sk_id]
@@ -339,7 +330,9 @@ def _draw_instances_kpts(self,
transparency = self.alpha
if self.show_keypoint_weight:
transparency *= max(
- 0, min(1, 0.5 * (score[sk[0]] + score[sk[1]])))
+ 0,
+ min(1,
+ 0.5 * (visible[sk[0]] + visible[sk[1]])))
if skeleton_style == 'openpose':
mX = np.mean(X)
@@ -365,8 +358,7 @@ def _draw_instances_kpts(self,
# draw each point on image
for kid, kpt in enumerate(kpts):
- if score[kid] < kpt_thr or not visible[
- kid] or kpt_color[kid] is None:
+ if visible[kid] < kpt_thr or kpt_color[kid] is None:
# skip the point that should not be drawn
continue
@@ -375,7 +367,7 @@ def _draw_instances_kpts(self,
color = tuple(int(c) for c in color)
transparency = self.alpha
if self.show_keypoint_weight:
- transparency *= max(0, min(1, score[kid]))
+ transparency *= max(0, min(1, visible[kid]))
self.draw_circles(
kpt,
radius=np.array([self.radius]),
diff --git a/tests/test_datasets/test_transforms/test_converting.py b/tests/test_datasets/test_transforms/test_converting.py
index 09f06e1e65..08561b1d0f 100644
--- a/tests/test_datasets/test_transforms/test_converting.py
+++ b/tests/test_datasets/test_transforms/test_converting.py
@@ -32,8 +32,10 @@ def test_transform(self):
self.assertTrue((results['keypoints'][:, target_index] ==
self.data_info['keypoints'][:,
source_index]).all())
+ self.assertEqual(results['keypoints_visible'].ndim, 3)
+ self.assertEqual(results['keypoints_visible'].shape[2], 2)
self.assertTrue(
- (results['keypoints_visible'][:, target_index] ==
+ (results['keypoints_visible'][:, target_index, 0] ==
self.data_info['keypoints_visible'][:, source_index]).all())
# 2-to-1 mapping
@@ -58,8 +60,10 @@ def test_transform(self):
(results['keypoints'][:, target_index] == 0.5 *
(self.data_info['keypoints'][:, source_index] +
self.data_info['keypoints'][:, source_index2])).all())
+ self.assertEqual(results['keypoints_visible'].ndim, 3)
+ self.assertEqual(results['keypoints_visible'].shape[2], 2)
self.assertTrue(
- (results['keypoints_visible'][:, target_index] ==
+ (results['keypoints_visible'][:, target_index, 0] ==
self.data_info['keypoints_visible'][:, source_index] *
self.data_info['keypoints_visible'][:,
source_index2]).all())
@@ -67,7 +71,9 @@ def test_transform(self):
self.assertTrue(
(results['keypoints'][:, target_index] ==
self.data_info['keypoints'][:, source_index]).all())
+ self.assertEqual(results['keypoints_visible'].ndim, 3)
+ self.assertEqual(results['keypoints_visible'].shape[2], 2)
self.assertTrue(
- (results['keypoints_visible'][:, target_index] ==
+ (results['keypoints_visible'][:, target_index, 0] ==
self.data_info['keypoints_visible'][:,
source_index]).all())