[Fix] Fix the incorrect labels for training vis_head with combined datasets #2550

Merged · 8 commits · Jul 27, 2023
@@ -0,0 +1,196 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
train_cfg = dict(max_epochs=210, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
type='Adam',
lr=5e-4,
))

# learning policy
param_scheduler = [
dict(
type='LinearLR', begin=0, end=500, start_factor=0.001,
by_epoch=False), # warm-up
dict(
type='MultiStepLR',
begin=0,
end=210,
milestones=[170, 200],
gamma=0.1,
by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(32, 64, 128, 256))),
init_cfg=dict(
type='Pretrained',
checkpoint='https://download.openmmlab.com/mmpose/'
'pretrain_models/hrnet_w32-36af842e.pth'),
),
head=dict(
type='VisPredictHead',
loss=dict(
type='BCELoss',
use_target_weight=True,
use_sigmoid=True,
loss_weight=1e-4,
),
use_sigmoid=False,
pose_cfg=dict(
type='HeatmapHead',
in_channels=32,
out_channels=17,
deconv_out_channels=None,
loss=dict(type='KeypointMSELoss', use_target_weight=True),
decoder=codec)),
test_cfg=dict(
flip_test=True,
flip_mode='heatmap',
shift_heatmap=True,
))

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'data/coco/'

# pipelines
train_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(type='RandomBBoxTransform'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage'),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]

# train datasets
dataset_coco = dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_train2017.json',
data_prefix=dict(img='train2017/'),
pipeline=[],
)

dataset_aic = dict(
type='AicDataset',
data_root='data/aic/',
data_mode=data_mode,
ann_file='annotations/aic_train.json',
data_prefix=dict(img='ai_challenger_keypoint_train_20170902/'
'keypoint_train_images_20170902/'),
pipeline=[
dict(
type='KeypointConverter',
num_keypoints=17,
mapping=[
(0, 6),
(1, 8),
(2, 10),
(3, 5),
(4, 7),
(5, 9),
(6, 12),
(7, 14),
(8, 16),
(9, 11),
(10, 13),
(11, 15),
])
],
)

# data loaders
train_dataloader = dict(
batch_size=64,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='CombinedDataset',
metainfo=dict(from_file='configs/_base_/datasets/coco.py'),
datasets=[dataset_coco, dataset_aic],
pipeline=train_pipeline,
test_mode=False,
))
val_dataloader = dict(
batch_size=32,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='annotations/person_keypoints_val2017.json',
bbox_file='data/coco/person_detection_results/'
'COCO_val2017_detections_AP_H_56_person.json',
data_prefix=dict(img='val2017/'),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'annotations/person_keypoints_val2017.json')
test_evaluator = val_evaluator
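
For reference, a toy sketch (not part of the PR) of what the KeypointConverter mapping in dataset_aic does: the 12 mapped AIC joints are scattered into a 17-slot COCO layout, and unmapped COCO joints (nose, eyes, ears) keep zero visibility. Shapes and values below are made up.

import numpy as np

# (source AIC index, target COCO index) pairs copied from the config above
mapping = [(0, 6), (1, 8), (2, 10), (3, 5), (4, 7), (5, 9),
           (6, 12), (7, 14), (8, 16), (9, 11), (10, 13), (11, 15)]
src = [s for s, _ in mapping]
tgt = [t for _, t in mapping]

aic_kpts = np.random.rand(1, 14, 2)   # one instance, 14 AIC keypoints (toy)
aic_vis = np.ones((1, 14))            # pretend every AIC joint is labeled

coco_kpts = np.zeros((1, 17, 2))      # COCO order has 17 keypoints
coco_vis = np.zeros((1, 17))
coco_kpts[:, tgt] = aic_kpts[:, src]  # scatter mapped joints
coco_vis[:, tgt] = aic_vis[:, src]

# Unmapped slots (0-4: nose, eyes, ears) stay at zero visibility; the
# weighting fix in this PR stops treating them as "invisible" labels.
print(coco_vis[0])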
20 changes: 8 additions & 12 deletions mmpose/datasets/transforms/common_transforms.py
@@ -340,7 +340,7 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,

        Args:
            keypoints_visible (np.ndarray, optional): The visibility of
-                keypoints in shape (N, K, 1).
+                keypoints in shape (N, K, 1) or (N, K, 2).
            upper_body_ids (list): The list of upper body keypoint indices
            lower_body_ids (list): The list of lower body keypoint indices

@@ -349,6 +349,9 @@ def _random_select_half_body(self, keypoints_visible: np.ndarray,
            of each instance. ``None`` means not applying half-body transform.
        """

+        if keypoints_visible.ndim == 3:
+            keypoints_visible = keypoints_visible[..., 0]
+
        half_body_ids = []

        for visible in keypoints_visible:
@@ -390,7 +393,6 @@ def transform(self, results: Dict) -> Optional[dict]:
        Returns:
            dict: The result dict.
        """
-
        half_body_ids = self._random_select_half_body(
            keypoints_visible=results['keypoints_visible'],
            upper_body_ids=results['upper_body_ids'],
@@ -952,6 +954,10 @@ def transform(self, results: Dict) -> Optional[dict]:
                             ' \'keypoints\' in the results.')

        keypoints_visible = results['keypoints_visible']
+        if keypoints_visible.ndim == 3 and keypoints_visible.shape[2] == 2:
+            keypoints_visible, keypoints_visible_weights = \
+                keypoints_visible[..., 0], keypoints_visible[..., 1]
+            results['keypoints_visible_weights'] = keypoints_visible_weights

        # Encoded items from the encoder(s) will be updated into the results.
        # Please refer to the document of the specific codec for details about
@@ -1031,16 +1037,6 @@ def transform(self, results: Dict) -> Optional[dict]:

        results.update(encoded)

-        if results.get('keypoint_weights', None) is not None:
-            results['transformed_keypoints_visible'] = results[
-                'keypoint_weights']
-        elif results.get('keypoints', None) is not None:
-            results['transformed_keypoints_visible'] = results[
-                'keypoints_visible']
-        else:
-            raise ValueError('GenerateTarget requires \'keypoint_weights\' or'
-                             ' \'keypoints_visible\' in the results.')
-
        return results

    def __repr__(self) -> str:
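In short, keypoints_visible may now arrive as (N, K, 2): channel 0 carries the visibility labels, channel 1 a per-joint weight that is 0 for joints the source dataset never annotates. A minimal sketch of the split GenerateTarget now performs, on toy arrays:

import numpy as np

# channel 0: visibility labels; channel 1: loss weights
# (0 = joint not annotated in the source dataset, e.g. AIC has no eyes/ears)
keypoints_visible = np.stack(
    [np.array([[1., 1., 0.]]),   # visibility of 3 toy keypoints
     np.array([[1., 0., 1.]])],  # weights: keypoint 1 is unlabeled
    axis=2)                      # shape (1, 3, 2)

if keypoints_visible.ndim == 3 and keypoints_visible.shape[2] == 2:
    keypoints_visible, keypoints_visible_weights = \
        keypoints_visible[..., 0], keypoints_visible[..., 1]

print(keypoints_visible.shape, keypoints_visible_weights.shape)  # (1, 3) (1, 3)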
15 changes: 12 additions & 3 deletions mmpose/datasets/transforms/converting.py
@@ -87,13 +87,18 @@ def __init__(self, num_keypoints: int,
        self.interpolation = interpolation

    def transform(self, results: dict) -> dict:
+        """Transforms the keypoint results to match the target keypoints."""
        num_instances = results['keypoints'].shape[0]

+        # Initialize output arrays
        keypoints = np.zeros((num_instances, self.num_keypoints, 2))
        keypoints_visible = np.zeros((num_instances, self.num_keypoints))

-        # When paired source_indexes are input,
-        # perform interpolation with self.source_index and self.source_index2
+        # Create a mask to weight visibility loss
+        keypoints_visible_weights = keypoints_visible.copy()
+        keypoints_visible_weights[:, self.target_index] = 1.0
+
+        # Interpolate keypoints if pairs of source indexes provided
        if self.interpolation:
            keypoints[:, self.target_index] = 0.5 * (
                results['keypoints'][:, self.source_index] +
@@ -102,15 +107,19 @@ def transform(self, results: dict) -> dict:
            keypoints_visible[:, self.target_index] = results[
                'keypoints_visible'][:, self.source_index] * \
                results['keypoints_visible'][:, self.source_index2]
+
+        # Otherwise just copy from the source index
        else:
            keypoints[:,
                      self.target_index] = results['keypoints'][:, self.
                                                                source_index]
            keypoints_visible[:, self.target_index] = results[
                'keypoints_visible'][:, self.source_index]

+        # Update the results dict
        results['keypoints'] = keypoints
-        results['keypoints_visible'] = keypoints_visible
+        results['keypoints_visible'] = np.stack(
+            [keypoints_visible, keypoints_visible_weights], axis=2)
        return results

    def __repr__(self) -> str:
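The interpolation branch above is not exercised by the COCO+AIC config (its mapping contains only single source indexes), but for completeness, a toy example of what it computes when a target joint is defined by a pair of sources:

import numpy as np

kpts = np.array([[[0., 0.], [2., 0.]]])  # one instance, two source joints
vis = np.array([[1., 1.]])

# target joint = midpoint of the two sources (e.g. a synthetic neck)
mid = 0.5 * (kpts[:, 0] + kpts[:, 1])    # -> [[1., 0.]]
mid_vis = vis[:, 0] * vis[:, 1]          # visible only if both sources are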
6 changes: 1 addition & 5 deletions mmpose/datasets/transforms/formatting.py
@@ -128,7 +128,7 @@ class PackPoseInputs(BaseTransform):
        'keypoint_y_labels': 'keypoint_y_labels',
        'keypoint_weights': 'keypoint_weights',
        'instance_coords': 'instance_coords',
-        'transformed_keypoints_visible': 'keypoints_visible',
+        'keypoints_visible_weights': 'keypoints_visible_weights'
    }

# items in `field_mapping_table` will be packed into
@@ -195,10 +195,6 @@ def transform(self, results: dict) -> dict:
        if self.pack_transformed and 'transformed_keypoints' in results:
            gt_instances.set_field(results['transformed_keypoints'],
                                   'transformed_keypoints')
-        if self.pack_transformed and \
-                'transformed_keypoints_visible' in results:
-            gt_instances.set_field(results['transformed_keypoints_visible'],
-                                   'transformed_keypoints_visible')

        data_sample.gt_instances = gt_instances

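The field_mapping_table entry is what carries the new weights into gt_instance_labels. Roughly, and only as a simplified sketch of the packing step (assumed behavior, not the actual PackPoseInputs code):

field_mapping_table = {
    'keypoint_weights': 'keypoint_weights',
    'keypoints_visible_weights': 'keypoints_visible_weights',
}
results = {'keypoint_weights': [[1., 0.]], 'img_path': 'x.jpg'}

# keys present in `results` are renamed and packed into the data sample's
# gt_instance_labels; unrelated keys are ignored here
packed = {dst: results[src]
          for src, dst in field_mapping_table.items() if src in results}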
24 changes: 15 additions & 9 deletions mmpose/models/heads/hybrid_heads/vis_head.py
@@ -31,7 +31,7 @@ def __init__(self,
                 pose_cfg: ConfigType,
                 loss: ConfigType = dict(
                     type='BCELoss', use_target_weight=False,
-                    with_logits=True),
+                    use_sigmoid=True),
                 use_sigmoid: bool = False,
                 init_cfg: OptConfigType = None):

@@ -54,14 +54,12 @@ def __init__(self,
        self.pose_head = MODELS.build(pose_cfg)
        self.pose_cfg = pose_cfg

-        self.use_sigmoid = use_sigmoid
-
        modules = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(self.in_channels, self.out_channels)
        ]
-        if use_sigmoid:
+        if self.loss_module.use_sigmoid:
            modules.append(nn.Sigmoid())

        self.vis_head = nn.Sequential(*modules)
@@ -168,7 +166,7 @@ def predict(self,

        batch_vis.unsqueeze_(dim=1)  # (B, N, K, D)

-        if not self.use_sigmoid:
+        if not self.loss_module.use_sigmoid:
            batch_vis = torch.sigmoid(batch_vis)

        batch_pose = self.pose_head.predict(feats, batch_data_samples,
@@ -203,13 +201,21 @@ def loss(self,
            dict: A dictionary of losses.
        """
        vis_pred_outputs = self.vis_forward(feats)
-        vis_labels = torch.cat([
-            d.gt_instance_labels.keypoint_weights for d in batch_data_samples
-        ])
+        vis_labels = []
+        vis_weights = [] if self.loss_module.use_target_weight else None
+        for d in batch_data_samples:
+            vis_label = d.gt_instance_labels.keypoint_weights.float()
+            vis_labels.append(vis_label)
+            if vis_weights is not None:
+                vis_weights.append(
+                    getattr(d.gt_instance_labels, 'keypoints_visible_weights',
+                            vis_label.new_ones(vis_label.shape)))
+        vis_labels = torch.cat(vis_labels)
+        vis_weights = torch.cat(vis_weights) if vis_weights else None

        # calculate vis losses
        losses = dict()
-        loss_vis = self.loss_module(vis_pred_outputs, vis_labels)
+        loss_vis = self.loss_module(vis_pred_outputs, vis_labels, vis_weights)

        losses.update(loss_vis=loss_vis)

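The effect of the new loss(): joints whose weight is 0 (absent from a sample's source dataset) no longer contribute to the visibility loss, instead of being trained toward "invisible". A self-contained sketch with hand-made tensors (the real code reads them from gt_instance_labels):

import torch

# two toy samples: per-joint visibility labels and weights
# (weight 0 = the joint does not exist in that sample's source dataset)
samples = [
    (torch.tensor([[1., 1., 0.]]), torch.tensor([[1., 1., 1.]])),  # COCO-like
    (torch.tensor([[1., 0., 1.]]), torch.tensor([[1., 0., 1.]])),  # AIC-like
]

vis_labels = torch.cat([lab for lab, _ in samples])
vis_weights = torch.cat([w for _, w in samples])
# BCELoss(use_target_weight=True) scales the per-joint loss by vis_weights,
# so joints missing from a dataset are masked out of the gradient.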
8 changes: 5 additions & 3 deletions mmpose/models/losses/classification_loss.py
@@ -14,15 +14,17 @@ class BCELoss(nn.Module):
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
-        with_logits (bool): Whether to use BCEWithLogitsLoss. Default: False.
+        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+            before output. Defaults to False.
    """

    def __init__(self,
                 use_target_weight=False,
                 loss_weight=1.,
-                 with_logits=False):
+                 use_sigmoid=False):
        super().__init__()
-        self.criterion = F.binary_cross_entropy if not with_logits\
+        self.use_sigmoid = use_sigmoid
+        self.criterion = F.binary_cross_entropy if use_sigmoid \
            else F.binary_cross_entropy_with_logits
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight
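The renamed flag keeps the criterion choice consistent with the head: when the head already ends in a Sigmoid (use_sigmoid=True), plain binary cross-entropy is applied to probabilities; otherwise the numerically safer logits variant consumes raw scores. A small sketch of both the choice and the target weighting, with toy values:

import torch
import torch.nn.functional as F

use_sigmoid = False  # head emits raw logits
criterion = (F.binary_cross_entropy if use_sigmoid
             else F.binary_cross_entropy_with_logits)

pred = torch.tensor([[2.0, -1.0, 0.5]])
target = torch.tensor([[1.0, 0.0, 1.0]])
weight = torch.tensor([[1.0, 1.0, 0.0]])  # mask the unlabeled joint

loss = (criterion(pred, target, reduction='none') * weight).mean()
print(loss)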