diff --git a/mmpose/core/evaluation/bottom_up_eval.py b/mmpose/core/evaluation/bottom_up_eval.py index 422ff40821..fb37a3373b 100644 --- a/mmpose/core/evaluation/bottom_up_eval.py +++ b/mmpose/core/evaluation/bottom_up_eval.py @@ -133,8 +133,9 @@ def get_multi_stage_outputs_paf(outputs, resize them to base sizes. Args: - outputs (list(torch.Tensor)): Outputs of network - outputs_flip (list(torch.Tensor)): Flip outputs of network + outputs (dict): Outputs of network, including heatmaps and pafs. + outputs_flip (dict): Flip outputs of network, including + heatmaps and pafs. with_heatmaps (list[bool]): Option to output heatmaps for different stages. with_pafs (list[bool]): Option to output diff --git a/mmpose/core/post_processing/group.py b/mmpose/core/post_processing/group.py index 9a870d6ce3..7063c6ee64 100644 --- a/mmpose/core/post_processing/group.py +++ b/mmpose/core/post_processing/group.py @@ -404,3 +404,272 @@ def parse(self, heatmaps, tags, adjust=True, refine=True): ans = [ans] return ans, scores + + +class PAFParser: + """The paf parser for post processing.""" + + def __init__(self, cfg): + self.params = _Params(cfg) + self.tag_per_joint = cfg['tag_per_joint'] + self.pool = torch.nn.MaxPool2d(cfg['nms_kernel'], 1, + cfg['nms_padding']) + self.use_udp = cfg.get('use_udp', False) + + # def nms(self, heatmaps): + # """Non-Maximum Suppression for heatmaps. + # + # Args: + # heatmap(torch.Tensor): Heatmaps before nms. + # + # Returns: + # torch.Tensor: Heatmaps after nms. + # """ + # + # maxm = self.pool(heatmaps) + # maxm = torch.eq(maxm, heatmaps).float() + # heatmaps = heatmaps * maxm + # + # return heatmaps + # + # def match(self, tag_k, loc_k, val_k): + # """Group keypoints to human poses in a batch. + # + # Args: + # tag_k (np.ndarray[NxKxMxL]): tag corresponding to the + # top k values of feature map per keypoint. + # loc_k (np.ndarray[NxKxMx2]): top k locations of the + # feature maps for keypoint. + # val_k (np.ndarray[NxKxM]): top k value of the + # feature maps per keypoint. + # + # Returns: + # list + # """ + # + # def _match(x): + # return _match_by_tag(x, self.params) + # + # return list(map(_match, zip(tag_k, loc_k, val_k))) + # + # def top_k(self, heatmaps, tags): + # """Find top_k values in an image. + # + # Note: + # batch size: N + # number of keypoints: K + # heatmap height: H + # heatmap width: W + # max number of people: M + # dim of tags: L + # If use flip testing, L=2; else L=1. + # + # Args: + # heatmaps (torch.Tensor[NxKxHxW]) + # tags (torch.Tensor[NxKxHxWxL]) + # + # Return: + # dict: A dict containing top_k values. + # + # - tag_k (np.ndarray[NxKxMxL]): + # tag corresponding to the top k values of + # feature map per keypoint. + # - loc_k (np.ndarray[NxKxMx2]): + # top k location of feature map per keypoint. + # - val_k (np.ndarray[NxKxM]): + # top k value of feature map per keypoint. + # """ + # heatmaps = self.nms(heatmaps) + # N, K, H, W = heatmaps.size() + # heatmaps = heatmaps.view(N, K, -1) + # val_k, ind = heatmaps.topk(self.params.max_num_people, dim=2) + # + # tags = tags.view(tags.size(0), tags.size(1), W * H, -1) + # if not self.tag_per_joint: + # tags = tags.expand(-1, self.params.num_joints, -1, -1) + # + # tag_k = torch.stack( + # [torch.gather(tags[..., i], 2, ind) for i in + # range(tags.size(3))], + # dim=3) + # + # x = ind % W + # y = ind // W + # + # ind_k = torch.stack((x, y), dim=3) + # + # ans = { + # 'tag_k': tag_k.cpu().numpy(), + # 'loc_k': ind_k.cpu().numpy(), + # 'val_k': val_k.cpu().numpy() + # } + # + # return ans + # + # @staticmethod + # def adjust(ans, heatmaps): + # """Adjust the coordinates for better accuracy. + # + # Note: + # batch size: N + # number of keypoints: K + # heatmap height: H + # heatmap width: W + # + # Args: + # ans (list(np.ndarray)): Keypoint predictions. + # heatmaps (torch.Tensor[NxKxHxW]): Heatmaps. + # """ + # _, _, H, W = heatmaps.shape + # for batch_id, people in enumerate(ans): + # for people_id, people_i in enumerate(people): + # for joint_id, joint in enumerate(people_i): + # if joint[2] > 0: + # x, y = joint[0:2] + # xx, yy = int(x), int(y) + # tmp = heatmaps[batch_id][joint_id] + # if tmp[min(H - 1, yy + 1), xx] > tmp[max(0, yy - 1), + # xx]: + # y += 0.25 + # else: + # y -= 0.25 + # + # if tmp[yy, min(W - 1, xx + 1)] > tmp[yy, + # max(0, xx - 1)]: + # x += 0.25 + # else: + # x -= 0.25 + # ans[batch_id][people_id, joint_id, + # 0:2] = (x + 0.5, y + 0.5) + # return ans + # + # @staticmethod + # def refine(heatmap, tag, keypoints, use_udp=False): + # """Given initial keypoint predictions, we identify missing joints. + # + # Note: + # number of keypoints: K + # heatmap height: H + # heatmap width: W + # dim of tags: L + # If use flip testing, L=2; else L=1. + # + # Args: + # heatmap: np.ndarray(K, H, W). + # tag: np.ndarray(K, H, W) | np.ndarray(K, H, W, L) + # keypoints: np.ndarray of size (K, 3 + L) + # last dim is (x, y, score, tag). + # use_udp: bool-unbiased data processing + # + # Returns: + # np.ndarray: The refined keypoints. + # """ + # + # K, H, W = heatmap.shape + # if len(tag.shape) == 3: + # tag = tag[..., None] + # + # tags = [] + # for i in range(K): + # if keypoints[i, 2] > 0: + # # save tag value of detected keypoint + # x, y = keypoints[i][:2].astype(int) + # x = np.clip(x, 0, W - 1) + # y = np.clip(y, 0, H - 1) + # tags.append(tag[i, y, x]) + # + # # mean tag of current detected people + # prev_tag = np.mean(tags, axis=0) + # ans = [] + # + # for _heatmap, _tag in zip(heatmap, tag): + # # distance of all tag values with mean tag of + # # current detected people + # distance_tag = (((_tag - + # prev_tag[None, None, :])**2).sum(axis=2)**0.5) + # norm_heatmap = _heatmap - np.round(distance_tag) + # + # # find maximum position + # y, x = np.unravel_index(np.argmax(norm_heatmap), _heatmap.shape) + # xx = x.copy() + # yy = y.copy() + # # detection score at maximum position + # val = _heatmap[y, x] + # if not use_udp: + # # offset by 0.5 + # x += 0.5 + # y += 0.5 + # + # # add a quarter offset + # if _heatmap[yy, min(W - 1, xx + 1)] > _ + # heatmap[yy, max(0, xx - 1)]: + # x += 0.25 + # else: + # x -= 0.25 + # + # if _heatmap[min(H - 1, yy + 1), xx] > _ + # heatmap[max(0, yy - 1), xx]: + # y += 0.25 + # else: + # y -= 0.25 + # + # ans.append((x, y, val)) + # ans = np.array(ans) + # + # if ans is not None: + # for i in range(K): + # # add keypoint if it is not detected + # if ans[i, 2] > 0 and keypoints[i, 2] == 0: + # keypoints[i, :3] = ans[i, :3] + # + # return keypoints + + def parse(self, heatmaps, pafs, adjust=True, refine=True): + """Group keypoints into poses given heatmap and paf. + + Note: + batch size: N + number of keypoints: K + number of paf maps: P + heatmap height: H + heatmap width: W + + Args: + heatmaps (torch.Tensor[NxKxHxW]): model output heatmaps. + pafs (torch.Tensor[NxPxHxW]): model output pafs. + + Returns: + tuple: A tuple containing keypoint grouping results. + + - ans (list(np.ndarray)): Pose results. + - scores (list): Score of people. + """ + + assert 0, 'The post-process of paf have not been completed.' + # ans = self.match(**self.top_k(heatmaps, pafs)) + # + # if adjust: + # if self.use_udp: + # for i in range(len(ans)): + # if ans[i].shape[0] > 0: + # ans[i][..., :2] = post_dark_udp( + # ans[i][..., :2].copy(), heatmaps[i:i + 1, :]) + # else: + # ans = self.adjust(ans, heatmaps) + # + # scores = [i[:, 2].mean() for i in ans[0]] + # + # if refine: + # ans = ans[0] + # # for every detected person + # for i in range(len(ans)): + # heatmap_numpy = heatmaps[0].cpu().numpy() + # tag_numpy = tags[0].cpu().numpy() + # if not self.tag_per_joint: + # tag_numpy = np.tile(tag_numpy, + # (self.params.num_joints, 1, 1, 1)) + # ans[i] = self.refine( + # heatmap_numpy, tag_numpy, ans[i], use_udp=self.use_udp) + # ans = [ans] + # + # return ans, scores diff --git a/mmpose/models/detectors/paf.py b/mmpose/models/detectors/paf.py index dd2bcdb2c2..74518766f5 100644 --- a/mmpose/models/detectors/paf.py +++ b/mmpose/models/detectors/paf.py @@ -10,7 +10,7 @@ from mmpose.core.evaluation import (aggregate_results_paf, get_group_preds, get_multi_stage_outputs_paf) -from mmpose.core.post_processing.group import HeatmapParser +from mmpose.core.post_processing.group import PAFParser from .. import builder from ..registry import POSENETS from .base import BasePose @@ -56,7 +56,7 @@ def __init__(self, self.train_cfg = train_cfg self.test_cfg = test_cfg self.use_udp = test_cfg.get('use_udp', False) - self.parser = HeatmapParser(self.test_cfg) + self.parser = PAFParser(self.test_cfg) self.init_weights(pretrained=pretrained) @property @@ -74,7 +74,6 @@ def forward(self, img=None, targets=None, masks=None, - joints=None, img_metas=None, return_loss=True, return_heatmap=False, @@ -92,11 +91,10 @@ def forward(self, max_num_people: M Args: img(torch.Tensor[NxCximgHximgW]): Input image. - targets(List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps. + targets (list(list)): List of heatmaps and pafs, each of which + multi-scale targets. masks(List(torch.Tensor[NxHxW])): Masks of multi-scale target - heatmaps - joints(List(torch.Tensor[NxMxKx2])): Joints of multi-scale target - heatmaps for ae loss + heatmaps. img_metas(dict):Information about val&test By default this includes: - "image_file": image path @@ -118,12 +116,11 @@ def forward(self, """ if return_loss: - return self.forward_train(img, targets, masks, joints, img_metas, - **kwargs) + return self.forward_train(img, targets, masks, img_metas, **kwargs) return self.forward_test( img, img_metas, return_heatmap=return_heatmap, **kwargs) - def forward_train(self, img, targets, masks, joints, img_metas, **kwargs): + def forward_train(self, img, targets, masks, img_metas, **kwargs): """Forward the bottom-up model and calculate the loss. Note: @@ -138,11 +135,10 @@ def forward_train(self, img, targets, masks, joints, img_metas, **kwargs): Args: img(torch.Tensor[NxCximgHximgW]): Input image. - targets(List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps. + targets (list(list)): List of heatmaps and pafs, each of which + multi-scale targets. masks(List(torch.Tensor[NxHxW])): Masks of multi-scale target - heatmaps - joints(List(torch.Tensor[NxMxKx2])): Joints of multi-scale target - heatmaps for ae loss + heatmaps. img_metas(dict):Information about val&test By default this includes: - "image_file": image path @@ -166,7 +162,7 @@ def forward_train(self, img, targets, masks, joints, img_metas, **kwargs): losses = dict() if self.with_keypoint: keypoint_losses = self.keypoint_head.get_loss( - output, targets, masks, joints) + output, targets, masks) losses.update(keypoint_losses) return losses @@ -240,7 +236,7 @@ def forward_test(self, img, img_metas, return_heatmap=False, **kwargs): outputs, outputs_flipped, self.test_cfg['with_heatmaps'], - self.test_cfg['with_ae'], + self.test_cfg['with_pafs'], img_metas['flip_index'], img_metas['flip_index_paf'], self.test_cfg['project2image'], @@ -259,10 +255,11 @@ def forward_test(self, img, img_metas, return_heatmap=False, **kwargs): # average heatmaps of different scales aggregated_heatmaps = aggregated_heatmaps / float( len(test_scale_factor)) - tags = torch.cat(pafs, dim=4) + aggregated_pafs = aggregated_pafs / float(len(test_scale_factor)) # perform grouping - grouped, scores = self.parser.parse(aggregated_heatmaps, tags, + grouped, scores = self.parser.parse(aggregated_heatmaps, + aggregated_pafs, self.test_cfg['adjust'], self.test_cfg['refine']) diff --git a/mmpose/models/keypoint_heads/__init__.py b/mmpose/models/keypoint_heads/__init__.py index 5f3dbc63f8..44432e28ef 100644 --- a/mmpose/models/keypoint_heads/__init__.py +++ b/mmpose/models/keypoint_heads/__init__.py @@ -4,13 +4,15 @@ from .heatmap_1d_head import Heatmap1DHead from .heatmap_3d_head import HeatMap3DHead from .multilabel_classification_head import MultilabelClassificationHead +from .paf_head import PAFHead +from .paf_simple_head import PAFSimpleHead from .temporal_regression_head import TemporalRegressionHead from .top_down_multi_stage_head import TopDownMSMUHead, TopDownMultiStageHead from .top_down_simple_head import TopDownSimpleHead __all__ = [ 'TopDownSimpleHead', 'TopDownMultiStageHead', 'TopDownMSMUHead', - 'BottomUpHigherResolutionHead', 'BottomUpSimpleHead', 'FcHead', - 'TemporalRegressionHead', 'HeatMap3DHead', 'Heatmap1DHead', - 'MultilabelClassificationHead' + 'BottomUpHigherResolutionHead', 'BottomUpSimpleHead', 'PAFHead', + 'PAFSimpleHead', 'FcHead', 'TemporalRegressionHead', 'HeatMap3DHead', + 'Heatmap1DHead', 'MultilabelClassificationHead' ] diff --git a/mmpose/models/keypoint_heads/bottom_up_higher_resolution_head.py b/mmpose/models/keypoint_heads/bottom_up_higher_resolution_head.py index 9460d6024f..92f14139ff 100644 --- a/mmpose/models/keypoint_heads/bottom_up_higher_resolution_head.py +++ b/mmpose/models/keypoint_heads/bottom_up_higher_resolution_head.py @@ -169,7 +169,7 @@ def _get_deconv_cfg(deconv_kernel): return deconv_kernel, padding, output_padding - def get_loss(self, output, targets, masks, joints): + def get_loss(self, outputs, targets, masks, joints): """Calculate bottom-up keypoint loss. Note: @@ -180,18 +180,18 @@ def get_loss(self, output, targets, masks, joints): heatmaps weight: W Args: - output (torch.Tensor[NxKxHxW]): Output heatmaps. - targets(List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps. - masks(List(torch.Tensor[NxHxW])): Masks of multi-scale target - heatmaps - joints(List(torch.Tensor[NxMxKx2])): Joints of multi-scale target - heatmaps for ae loss + outputs (List(torch.Tensor[NxKxHxW])): Multi-scale output heatmaps. + targets (List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps. + masks (List(torch.Tensor[NxHxW])): Masks of multi-scale target + heatmaps + joints (List(torch.Tensor[NxMxKx2])): Joints of multi-scale target + heatmaps for ae loss """ losses = dict() heatmaps_losses, push_losses, pull_losses = self.loss( - output, targets, masks, joints) + outputs, targets, masks, joints) for idx in range(len(targets)): if heatmaps_losses[idx] is not None: diff --git a/mmpose/models/keypoint_heads/bottom_up_simple_head.py b/mmpose/models/keypoint_heads/bottom_up_simple_head.py index 534bc42189..110419c486 100644 --- a/mmpose/models/keypoint_heads/bottom_up_simple_head.py +++ b/mmpose/models/keypoint_heads/bottom_up_simple_head.py @@ -82,7 +82,7 @@ def __init__(self, stride=1, padding=padding) - def get_loss(self, output, targets, masks, joints): + def get_loss(self, outputs, targets, masks, joints): """Calculate bottom-up keypoint loss. Note: @@ -93,9 +93,9 @@ def get_loss(self, output, targets, masks, joints): heatmaps weight: W Args: - output (torch.Tensor[NxKxHxW]): Output heatmaps. - targets(List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps. - masks(List(torch.Tensor[NxHxW])): Masks of multi-scale target + outputs (list(torch.Tensor[NxKxHxW])): Multi-scale output heatmaps. + targets (List(torch.Tensor[NxKxHxW])): Multi-scale target heatmaps. + masks (List(torch.Tensor[NxHxW])): Masks of multi-scale target heatmaps joints(List(torch.Tensor[NxMxKx2])): Joints of multi-scale target heatmaps for ae loss @@ -104,7 +104,7 @@ def get_loss(self, output, targets, masks, joints): losses = dict() heatmaps_losses, push_losses, pull_losses = self.loss( - output, targets, masks, joints) + outputs, targets, masks, joints) for idx in range(len(targets)): if heatmaps_losses[idx] is not None: diff --git a/mmpose/models/keypoint_heads/paf_head.py b/mmpose/models/keypoint_heads/paf_head.py new file mode 100644 index 0000000000..1cb32189be --- /dev/null +++ b/mmpose/models/keypoint_heads/paf_head.py @@ -0,0 +1,108 @@ +import torch.nn as nn + +from mmpose.models.builder import build_head +from ..registry import HEADS + + +@HEADS.register_module() +class PAFHead(nn.Module): + """Bottom-up PAF head. + + Args: + heatmap_heads_cfg (list(dict)): Configs of heatmap heads. + paf_heads_cfg (list(dict)): Configs of paf heads. + heatmap_index (list(int)): The correspondence between heatmap heads + and input features. + paf_index (list(int)): The correspondence between paf heads + and input features. + """ + + def __init__(self, heatmap_heads_cfg, paf_heads_cfg, heatmap_index, + paf_index): + super().__init__() + + assert len(heatmap_heads_cfg) == len(heatmap_index) + assert len(paf_heads_cfg) == len(paf_index) + + # build heatmap heads + self.heatmap_heads_list = [] + for head_cfg in heatmap_heads_cfg: + self.heatmap_heads_list.append(build_head(head_cfg)) + + # build paf heads + self.paf_heads_list = [] + for head_cfg in paf_heads_cfg: + self.paf_heads_list.append(build_head(head_cfg)) + + self.heatmap_index = heatmap_index + self.paf_index = paf_index + + def get_loss(self, outputs, targets, masks): + """Calculate heatmap and paf loss. + + Note: + batch_size: N + num_channels: C + heatmaps height: H + heatmaps weight: W + + Args: + outputs (dict): Outputs of network, including heatmaps and pafs. + targets (list(list)): List of heatmaps and pafs, each of which + multi-scale targets. + masks (list(torch.Tensor[NxHxW])): Masks of multi-scale target + heatmaps and pafs. + """ + + losses = dict() + + heatmap_outputs = outputs['heatmaps'] + heatmap_targets = targets[:len(self.heatmap_heads_list)] + for idx, head in enumerate(self.heatmap_heads_list): + heatmap_losses = head.get_loss(heatmap_outputs[idx], + heatmap_targets[idx], masks) + if 'heatmap_loss' not in losses: + losses['heatmap_loss'] = heatmap_losses['loss'] + else: + losses['heatmap_loss'] += heatmap_losses['loss'] + + paf_outputs = outputs['pafs'] + paf_targets = targets[len(self.heatmap_heads_list):] + for idx, head in enumerate(self.paf_heads_list): + paf_losses = head.get_loss(paf_outputs[idx], paf_targets[idx], + masks) + if 'paf_loss' not in losses: + losses['paf_loss'] = paf_losses['loss'] + else: + losses['paf_loss'] += paf_losses['loss'] + + return losses + + def forward(self, x): + """Forward function.""" + if not isinstance(x, list): + x = [x] + + assert max(self.heatmap_index) < len(x) + assert max(self.paf_index) < len(x) + + final_outputs = {'heatmaps': [], 'pafs': []} + + for idx, head in enumerate(self.heatmap_heads_list): + features = x[self.heatmap_index[idx]] + output = head(features) + final_outputs['heatmaps'].append(output) + + for idx, head in enumerate(self.paf_heads_list): + features = x[self.paf_index[idx]] + output = head(features) + final_outputs['pafs'].append(output) + + return final_outputs + + def init_weights(self): + for head in self.heatmap_heads_list: + head.init_weights() + + for head in self.paf_heads_list: + head.init_weights() diff --git a/mmpose/models/keypoint_heads/paf_simple_head.py b/mmpose/models/keypoint_heads/paf_simple_head.py new file mode 100644 index 0000000000..b78af28bfc --- /dev/null +++ b/mmpose/models/keypoint_heads/paf_simple_head.py @@ -0,0 +1,171 @@ +import torch.nn as nn +from mmcv.cnn import (build_conv_layer, build_upsample_layer, constant_init, + normal_init) + +from mmpose.models.builder import build_loss +from ..registry import HEADS + + +@HEADS.register_module() +class PAFSimpleHead(nn.Module): + """Bottom-up simple head. + + Args: + in_channels (int): Number of input channels. + num_joints (int): Number of joints. + num_deconv_layers (int): Number of deconv layers. + num_deconv_layers should >= 0. Note that 0 means + no deconv layers. + num_deconv_filters (list|tuple): Number of filters. + If num_deconv_layers > 0, the length of + num_deconv_kernels (list|tuple): Kernel sizes. + loss_keypoint (dict): Config for loss. Default: None. + """ + + def __init__(self, + in_channels, + num_joints, + num_deconv_layers=3, + num_deconv_filters=(256, 256, 256), + num_deconv_kernels=(4, 4, 4), + extra=None, + loss_keypoint=None): + super().__init__() + + self.loss = build_loss(loss_keypoint) + + self.in_channels = in_channels + out_channels = num_joints + + if extra is not None and not isinstance(extra, dict): + raise TypeError('extra should be dict or None.') + + if num_deconv_layers > 0: + self.deconv_layers = self._make_deconv_layer( + num_deconv_layers, + num_deconv_filters, + num_deconv_kernels, + ) + elif num_deconv_layers == 0: + self.deconv_layers = nn.Identity() + else: + raise ValueError( + f'num_deconv_layers ({num_deconv_layers}) should >= 0.') + + if extra is not None and 'final_conv_kernel' in extra: + assert extra['final_conv_kernel'] in [1, 3] + if extra['final_conv_kernel'] == 3: + padding = 1 + else: + padding = 0 + kernel_size = extra['final_conv_kernel'] + else: + kernel_size = 1 + padding = 0 + + self.final_layer = build_conv_layer( + cfg=dict(type='Conv2d'), + in_channels=num_deconv_filters[-1] + if num_deconv_layers > 0 else in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def get_loss(self, outputs, targets, masks): + """Calculate bottom-up masked mse loss. + + Note: + batch_size: N + num_channels: C + heatmaps height: H + heatmaps weight: W + + Args: + outputs (List(torch.Tensor[NxCxHxW])): Multi-scale outputs. + targets (List(torch.Tensor[NxCxHxW])): Multi-scale targets. + masks (List(torch.Tensor[NxHxW])): Masks of multi-scale targets. + """ + + losses = dict() + + for idx in range(len(targets)): + if 'loss' not in losses: + losses['loss'] = self.loss(outputs[idx], targets[idx], + masks[idx]) + else: + losses['loss'] += self.loss(outputs[idx], targets[idx], + masks[idx]) + + return losses + + def forward(self, x): + """Forward function.""" + if isinstance(x, list): + x = x[0] + final_outputs = [] + x = self.deconv_layers(x) + y = self.final_layer(x) + final_outputs.append(y) + return final_outputs + + def _make_deconv_layer(self, num_layers, num_filters, num_kernels): + """Make deconv layers.""" + if num_layers != len(num_filters): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_filters({len(num_filters)})' + raise ValueError(error_msg) + if num_layers != len(num_kernels): + error_msg = f'num_layers({num_layers}) ' \ + f'!= length of num_kernels({len(num_kernels)})' + raise ValueError(error_msg) + + layers = [] + for i in range(num_layers): + kernel, padding, output_padding = \ + self._get_deconv_cfg(num_kernels[i]) + + planes = num_filters[i] + layers.append( + build_upsample_layer( + dict(type='deconv'), + in_channels=self.in_channels, + out_channels=planes, + kernel_size=kernel, + stride=2, + padding=padding, + output_padding=output_padding, + bias=False)) + layers.append(nn.BatchNorm2d(planes)) + layers.append(nn.ReLU(inplace=True)) + self.in_channels = planes + + return nn.Sequential(*layers) + + @staticmethod + def _get_deconv_cfg(deconv_kernel): + """Get configurations for deconv layers.""" + if deconv_kernel == 4: + padding = 1 + output_padding = 0 + elif deconv_kernel == 3: + padding = 1 + output_padding = 1 + elif deconv_kernel == 2: + padding = 0 + output_padding = 0 + else: + raise ValueError(f'Not supported num_kernels ({deconv_kernel}).') + + return deconv_kernel, padding, output_padding + + def init_weights(self): + """Initialize model weights.""" + for _, m in self.deconv_layers.named_modules(): + if isinstance(m, nn.ConvTranspose2d): + normal_init(m, std=0.001) + elif isinstance(m, nn.BatchNorm2d): + constant_init(m, 1) + for m in self.final_layer.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001, bias=0) diff --git a/mmpose/models/losses/mse_loss.py b/mmpose/models/losses/mse_loss.py index f6cc72b611..bd26feeed3 100644 --- a/mmpose/models/losses/mse_loss.py +++ b/mmpose/models/losses/mse_loss.py @@ -150,3 +150,31 @@ def forward(self, output, target, target_weight): losses = torch.cat(losses, dim=1) return self._ohkm(losses) * self.loss_weight + + +@LOSSES.register_module() +class MaskedMSELoss(nn.Module): + """MSE loss for the bottom-up outputs with mask. + + Args: + use_mask (bool): Option to use mask of target. Default: True. + loss_weight (float): Weight of the loss. Default: 1.0. + """ + + def __init__(self, use_mask=True, loss_weight=1.): + super().__init__() + self.criterion = nn.MSELoss() + self.use_mask = use_mask + self.loss_weight = loss_weight + + def forward(self, output, target, mask): + """Forward function.""" + assert output.size() == target.size() + + if self.use_mask: + loss = self.criterion( + output, target) * mask[:, None, :, :].expand_as(output) + else: + loss = self.criterion(output, target) + + return loss * self.loss_weight