diff --git a/mmpose/codecs/__init__.py b/mmpose/codecs/__init__.py
index 60ebf8d424..e8797d7517 100644
--- a/mmpose/codecs/__init__.py
+++ b/mmpose/codecs/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .associative_embedding import AssociativeEmbedding
 from .megvii_heatmap import MegviiHeatmap
 from .msra_heatmap import MSRAHeatmap
 from .regression_label import RegressionLabel
@@ -7,5 +8,5 @@ __all__ = [
     'MSRAHeatmap', 'MegviiHeatmap', 'UDPHeatmap', 'RegressionLabel',
-    'SimCCLabel'
+    'SimCCLabel', 'AssociativeEmbedding'
 ]
diff --git a/mmpose/codecs/associative_embedding.py b/mmpose/codecs/associative_embedding.py
new file mode 100644
index 0000000000..14e585aac2
--- /dev/null
+++ b/mmpose/codecs/associative_embedding.py
@@ -0,0 +1,521 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import namedtuple
+from itertools import product
+from typing import Any, List, Optional, Tuple
+
+import numpy as np
+import torch
+from munkres import Munkres
+from torch import Tensor
+
+from mmpose.registry import KEYPOINT_CODECS
+from mmpose.utils.tensor_utils import to_numpy
+from .base import BaseKeypointCodec
+from .utils import (batch_heatmap_nms, generate_gaussian_heatmaps,
+                    generate_udp_gaussian_heatmaps, refine_keypoints,
+                    refine_keypoints_dark_udp)
+
+
+def _group_keypoints_by_tags(vals: np.ndarray,
+                             tags: np.ndarray,
+                             locs: np.ndarray,
+                             keypoint_order: List[int],
+                             val_thr: float,
+                             tag_dist_thr: float = 1.0,
+                             max_groups: Optional[int] = None):
+    """Group the keypoints by tags using the Munkres algorithm.
+
+    Note:
+
+        - keypoint number: K
+        - candidate number: M
+        - tag dimension: L
+        - coordinate dimension: D
+        - group number: G
+
+    Args:
+        vals (np.ndarray): The heatmap response values of keypoints in shape
+            (K, M)
+        tags (np.ndarray): The tags of the keypoint candidates in shape
+            (K, M, L)
+        locs (np.ndarray): The locations of the keypoint candidates in shape
+            (K, M, D)
+        keypoint_order (List[int]): The grouping order of the keypoints.
+            The grouping usually starts from keypoints around the head and
+            torso, and gradually moves out to the limbs
+        val_thr (float): The threshold of the keypoint response value
+        tag_dist_thr (float): The maximum allowed tag distance when matching
+            a keypoint to a group. A keypoint with a larger tag distance to
+            all existing groups will initialize a new group
+        max_groups (int, optional): The maximum group number. ``None`` means
+            no limitation. Defaults to ``None``
+
+    Returns:
+        tuple:
+        - grouped_keypoints (np.ndarray): The grouped keypoints in shape
+            (G, K, D)
+        - grouped_keypoint_scores (np.ndarray): The grouped keypoint scores
+            in shape (G, K)
+    """
+    K, M, D = locs.shape
+    assert vals.shape == tags.shape[:2] == (K, M)
+    assert len(keypoint_order) == K
+
+    # Build Munkres instance
+    munkres = Munkres()
+
+    # Build a group pool, each group contains the keypoints of an instance
+    groups = []
+
+    Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list'])
+
+    def _init_group():
+        """Initialize a group, which is composed of the keypoints, keypoint
+        scores and the tag of each keypoint."""
+        _group = Group(
+            kpts=np.zeros((K, D), dtype=np.float32),
+            scores=np.zeros(K, dtype=np.float32),
+            tag_list=[])
+        return _group
+
+    for i in keypoint_order:
+        # Get all valid candidates of the i-th keypoint
+        valid = vals[i] > val_thr
+        if not valid.any():
+            continue
+
+        tags_i = tags[i, valid]  # (M', L)
+        vals_i = vals[i, valid]  # (M',)
+        locs_i = locs[i, valid]  # (M', D)
+
+        if len(groups) == 0:  # Initialize the group pool
+            for tag, val, loc in zip(tags_i, vals_i, locs_i):
+                group = _init_group()
+                group.kpts[i] = loc
+                group.scores[i] = val
+                group.tag_list.append(tag)
+
+                groups.append(group)
+
+        else:  # Match keypoints to existing groups
+            groups = groups[:max_groups]
+            group_tags = [np.mean(g.tag_list, axis=0) for g in groups]
+
+            # Calculate distance matrix between group tags and tag candidates
+            # of the i-th keypoint
+            # Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
+            diff = tags_i[:, None] - np.array(group_tags)[None]
+            dists = np.linalg.norm(diff, ord=2, axis=2)
+            num_kpts, num_groups = dists.shape[:2]
+
+            # Experimental cost function for keypoint-group matching
+            costs = np.round(dists) * 100 - vals_i[..., None]
+            if num_kpts > num_groups:
+                padding = np.full((num_kpts, num_kpts - num_groups),
+                                  1e10,
+                                  dtype=np.float32)
+                costs = np.concatenate((costs, padding), axis=1)
+
+            # Match keypoints and groups by Munkres algorithm
+            matches = munkres.compute(costs)
+            for kpt_idx, group_idx in matches:
+                if group_idx < num_groups and dists[kpt_idx,
+                                                    group_idx] < tag_dist_thr:
+                    # Add the keypoint to the matched group
+                    group = groups[group_idx]
+                else:
+                    # Initialize a new group with the unmatched keypoint
+                    group = _init_group()
+                    groups.append(group)
+
+                group.kpts[i] = locs_i[kpt_idx]
+                group.scores[i] = vals_i[kpt_idx]
+                group.tag_list.append(tags_i[kpt_idx])
+
+    groups = groups[:max_groups]
+    grouped_keypoints = np.stack([g.kpts for g in groups])  # (G, K, D)
+    grouped_keypoint_scores = np.stack([g.scores for g in groups])  # (G, K)
+
+    return grouped_keypoints, grouped_keypoint_scores
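As a quick illustration of the grouping behavior (a sketch, not part of the patch; it calls the module-private helper above and uses made-up toy values):

```python
import numpy as np
from mmpose.codecs.associative_embedding import _group_keypoints_by_tags

# Two keypoints (K=2), two candidates each (M=2), 1-D tags (L=1), 2-D coords.
# Candidate 0 of each keypoint carries tag 0.0 and candidate 1 carries tag
# 1.0, so the candidates are separated into two groups (G=2).
vals = np.full((2, 2), 0.9, dtype=np.float32)
tags = np.zeros((2, 2, 1), dtype=np.float32)
tags[:, 1] = 1.0
locs = np.arange(8, dtype=np.float32).reshape(2, 2, 2)

kpts, scores = _group_keypoints_by_tags(
    vals, tags, locs, keypoint_order=[0, 1], val_thr=0.5)
print(kpts.shape, scores.shape)  # (2, 2, 2) (2, 2)
```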
+
+
+@KEYPOINT_CODECS.register_module()
+class AssociativeEmbedding(BaseKeypointCodec):
+    """Encode/decode keypoints with the method introduced in "Associative
+    Embedding". This is an asymmetric codec, where the keypoints are
+    represented as Gaussian heatmaps and position indices during encoding,
+    and restored from predicted heatmaps and group tags during decoding.
+
+    See the paper `Associative Embedding: End-to-End Learning for Joint
+    Detection and Grouping`_ by Newell et al (2017) for details
+
+    Note:
+
+        - instance number: N
+        - keypoint number: K
+        - keypoint dimension: D
+        - embedding tag dimension: L
+        - image size: [w, h]
+        - heatmap size: [W, H]
+
+    Args:
+        input_size (tuple): Image size in [w, h]
+        heatmap_size (tuple): Heatmap size in [W, H]
+        sigma (float): The sigma value of the Gaussian heatmap
+        use_udp (bool): Whether to use unbiased data processing. See
+            `UDP (CVPR 2020)`_ for details. Defaults to ``False``
+        decode_keypoint_order (List[int]): The grouping order of the
+            keypoint indices. The grouping usually starts from keypoints
+            around the head and torso, and gradually moves out to the limbs
+        decode_thr (float): The threshold of keypoint response value in
+            heatmaps. Defaults to 0.1
+        decode_nms_kernel (int): The kernel size of the NMS during decoding,
+            which should be an odd integer. Defaults to 5
+        decode_gaussian_kernel (int): The kernel size of the Gaussian blur
+            during decoding, which should be an odd integer. It is only used
+            when ``self.use_udp==True``. Defaults to 3
+        decode_topk (int): The number of top-k candidates of each keypoint
+            that will be retrieved from the heatmaps during decoding.
+            Defaults to 20
+        decode_max_instances (int, optional): The maximum number of instances
+            to decode. ``None`` means no limitation to the instance number.
+            Defaults to ``None``
+
+    .. _`Associative Embedding: End-to-End Learning for Joint Detection and
+    Grouping`: https://arxiv.org/abs/1611.05424
+    .. _`UDP (CVPR 2020)`: https://arxiv.org/abs/1911.07524
+    """
+
+    def __init__(self,
+                 input_size: Tuple[int, int],
+                 heatmap_size: Tuple[int, int],
+                 sigma: Optional[float] = None,
+                 use_udp: bool = False,
+                 decode_keypoint_order: List[int] = [],
+                 decode_nms_kernel: int = 5,
+                 decode_gaussian_kernel: int = 3,
+                 decode_thr: float = 0.1,
+                 decode_topk: int = 20,
+                 decode_max_instances: Optional[int] = None,
+                 tag_per_keypoint: bool = True) -> None:
+        super().__init__()
+        self.input_size = input_size
+        self.heatmap_size = heatmap_size
+        self.use_udp = use_udp
+        self.decode_nms_kernel = decode_nms_kernel
+        self.decode_gaussian_kernel = decode_gaussian_kernel
+        self.decode_thr = decode_thr
+        self.decode_topk = decode_topk
+        self.decode_max_instances = decode_max_instances
+        self.tag_per_keypoint = tag_per_keypoint
+        self.decode_keypoint_order = decode_keypoint_order.copy()
+
+        if sigma is None:
+            sigma = (heatmap_size[0] * heatmap_size[1])**0.5 / 64
+        self.sigma = sigma
+
+    def _get_scale_factor(self, input_size: Tuple[int, int],
+                          heatmap_size: Tuple[int, int]) -> np.ndarray:
+        """Calculate scale factors from the input size and the heatmap size.
+
+        Args:
+            input_size (tuple): Image size in [w, h]
+            heatmap_size (tuple): Heatmap size in [W, H]
+
+        Returns:
+            np.ndarray: scale factors in [fx, fy] where :math:`fx=w/W` and
+            :math:`fy=h/H`.
+        """
+        if self.use_udp:
+            scale_factor = ((np.array(input_size) - 1) /
+                            (np.array(heatmap_size) - 1)).astype(np.float32)
+        else:
+            scale_factor = (np.array(input_size) /
+                            heatmap_size).astype(np.float32)
+        return scale_factor
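For intuition, a sketch (not part of the patch) of the two scale-factor conventions; the UDP variant aligns the corner pixel centers of the two grids:

```python
import numpy as np

input_size, heatmap_size = (256, 256), (64, 64)
# default convention: fx = w / W
print(np.array(input_size) / heatmap_size)  # [4. 4.]
# UDP convention: fx = (w - 1) / (W - 1)
print((np.array(input_size) - 1) / (np.array(heatmap_size) - 1))  # ~[4.048 4.048]
```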
+
+    def encode(
+        self,
+        keypoints: np.ndarray,
+        keypoints_visible: Optional[np.ndarray] = None
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Encode keypoints into heatmaps and position indices. Note that the
+        original keypoint coordinates should be in the input image space.
+
+        Args:
+            keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
+            keypoints_visible (np.ndarray): Keypoint visibilities in shape
+                (N, K)
+
+        Returns:
+            tuple:
+            - heatmaps (np.ndarray): The generated heatmap in shape
+                (K, H, W) where [W, H] is the `heatmap_size`
+            - keypoint_indices (np.ndarray): The keypoint position indices
+                in shape (N, K, 2). Each keypoint's index is [i, v], where i
+                is the position index in the heatmap (:math:`i=y*w+x`) and v
+                is the visibility
+            - keypoint_weights (np.ndarray): The target weights in shape
+                (N, K)
+        """
+
+        scale_factor = self._get_scale_factor(self.input_size,
+                                              self.heatmap_size)
+
+        if keypoints_visible is None:
+            keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
+
+        # keypoint coordinates in heatmap
+        _keypoints = keypoints / scale_factor
+
+        if self.use_udp:
+            heatmaps, keypoint_weights = generate_udp_gaussian_heatmaps(
+                heatmap_size=self.heatmap_size,
+                keypoints=_keypoints,
+                keypoints_visible=keypoints_visible,
+                sigma=self.sigma)
+        else:
+            heatmaps, keypoint_weights = generate_gaussian_heatmaps(
+                heatmap_size=self.heatmap_size,
+                keypoints=_keypoints,
+                keypoints_visible=keypoints_visible,
+                sigma=self.sigma)
+
+        keypoint_indices = self._encode_keypoint_indices(
+            heatmap_size=self.heatmap_size,
+            keypoints=_keypoints,
+            keypoints_visible=keypoints_visible)
+
+        return heatmaps, keypoint_indices, keypoint_weights
+
+    def _encode_keypoint_indices(self, heatmap_size: Tuple[int, int],
+                                 keypoints: np.ndarray,
+                                 keypoints_visible: np.ndarray) -> np.ndarray:
+        w, h = heatmap_size
+        N, K, _ = keypoints.shape
+        keypoint_indices = np.zeros((N, K, 2), dtype=np.int64)
+
+        for n, k in product(range(N), range(K)):
+            x, y = (keypoints[n, k] + 0.5).astype(np.int64)
+            index = y * w + x
+            vis = (keypoints_visible[n, k] > 0.5 and 0 <= x < w
+                   and 0 <= y < h)
+            keypoint_indices[n, k] = [index, vis]
+
+        return keypoint_indices
+
+    def decode(self, encoded: Any) -> Tuple[np.ndarray, np.ndarray]:
+        raise NotImplementedError()
+
+    def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
+                        k: int):
+        """Get top-k response values from the heatmaps and corresponding tag
+        values from the tagging heatmaps.
+
+        Args:
+            batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
+                (B, K, H, W)
+            batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
+                the tag dim C is 2*K when using flip testing, or K otherwise
+            k (int): The number of top responses to get
+
+        Returns:
+            tuple:
+            - topk_vals (Tensor): Top-k response values of each heatmap in
+                shape (B, K, Topk)
+            - topk_tags (Tensor): The corresponding embedding tags of the
+                top-k responses, in shape (B, K, Topk, L)
+            - topk_locs (Tensor): The location of the top-k responses in each
+                heatmap, in shape (B, K, Topk, 2) where the last dimension
+                represents x and y coordinates
+        """
+        B, K, H, W = batch_heatmaps.shape
+        L = batch_tags.shape[1] // K
+
+        # shapes of topk_vals, topk_indices: (B, K, TopK)
+        topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
+            k, dim=-1)
+
+        topk_tags_per_kpts = [
+            torch.gather(_tag, dim=2, index=topk_indices)
+            for _tag in torch.unbind(batch_tags.view(B, K, L, H * W), dim=2)
+        ]
+
+        topk_tags = torch.stack(topk_tags_per_kpts, dim=-1)  # (B, K, TopK, L)
+        topk_locs = torch.stack([topk_indices % W, topk_indices // W],
+                                dim=-1)  # (B, K, TopK, 2)
+
+        return topk_vals, topk_tags, topk_locs
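A quick sanity check (illustrative only) of the flat-index to (x, y) conversion used for `topk_locs` above, with :math:`i = y \cdot W + x`:

```python
import torch

W = 4
topk_indices = torch.tensor([[[5, 9]]])  # (B=1, K=1, TopK=2), flat y*W+x
locs = torch.stack([topk_indices % W, topk_indices // W], dim=-1)
print(locs.tolist())  # [[[[1, 1], [1, 2]]]] -> index 5 is (x=1, y=1)
```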
+
+    def _group_keypoints(self, batch_vals: np.ndarray, batch_tags: np.ndarray,
+                         batch_locs: np.ndarray):
+        """Group keypoints into groups (each represents an instance) by tags.
+
+        Args:
+            batch_vals (np.ndarray): Heatmap response values of keypoint
+                candidates in shape (B, K, Topk)
+            batch_tags (np.ndarray): Tags of keypoint candidates in shape
+                (B, K, Topk, L)
+            batch_locs (np.ndarray): Locations of keypoint candidates in
+                shape (B, K, Topk, 2)
+
+        Returns:
+            List[Tuple[np.ndarray, np.ndarray]]: Grouping results of a batch,
+            each element is a tuple of keypoints (in shape [N, K, D]) and
+            keypoint scores (in shape [N, K]) decoded from one image.
+        """
+
+        def _group_func(inputs: Tuple):
+            vals, tags, locs = inputs
+            return _group_keypoints_by_tags(
+                vals,
+                tags,
+                locs,
+                keypoint_order=self.decode_keypoint_order,
+                val_thr=self.decode_thr,
+                max_groups=self.decode_max_instances)
+
+        _results = map(_group_func, zip(batch_vals, batch_tags, batch_locs))
+        results = list(_results)
+        return results
+
+    def _fill_missing_keypoints(self, keypoints: np.ndarray,
+                                keypoint_scores: np.ndarray,
+                                heatmaps: np.ndarray, tags: np.ndarray):
+        """Fill the missing keypoints in the initial predictions.
+
+        Args:
+            keypoints (np.ndarray): Keypoint predictions in shape (N, K, D)
+            keypoint_scores (np.ndarray): Keypoint score predictions in shape
+                (N, K), in which 0 means the corresponding keypoint is
+                missing in the initial prediction
+            heatmaps (np.ndarray): Heatmaps in shape (K, H, W)
+            tags (np.ndarray): Tagging heatmaps in shape (C, H, W) where
+                C=K*L
+
+        Returns:
+            tuple:
+            - keypoints (np.ndarray): Keypoint predictions with missing
+                ones filled
+            - keypoint_scores (np.ndarray): Keypoint score predictions with
+                missing ones filled
+        """
+
+        N, K = keypoints.shape[:2]
+        H, W = heatmaps.shape[1:]
+        keypoint_tags = np.split(tags, K, axis=0)
+
+        for n in range(N):
+            # Calculate the instance tag (mean tag of detected keypoints)
+            _tag = []
+            for k in range(K):
+                if keypoint_scores[n, k] > 0:
+                    x, y = keypoints[n, k, :2].astype(np.int64)
+                    x = np.clip(x, 0, W - 1)
+                    y = np.clip(y, 0, H - 1)
+                    _tag.append(keypoint_tags[k][:, y, x])
+            tag = np.mean(_tag, axis=0)
+
+            # Search maximum response of the missing keypoints
+            for k in range(K):
+                if keypoint_scores[n, k] > 0:
+                    continue
+                dist_map = np.linalg.norm(
+                    keypoint_tags[k] - tag, ord=2, axis=0)
+                cost_map = np.round(dist_map) * 100 - heatmaps[k]  # (H, W)
+                y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
+                keypoints[n, k] = [x, y]
+                keypoint_scores[n, k] = heatmaps[k, y, x]
+
+        return keypoints, keypoint_scores
+
+    def batch_decode(
+        self,
+        batch_heatmaps: Tensor,
+        batch_tags: Tensor,
+        input_sizes: Optional[List[Tuple[int, int]]] = None
+    ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
+        """Decode the keypoint coordinates from a batch of heatmaps and
+        tagging heatmaps. The decoded keypoint coordinates are in the input
+        image space.
+
+        Args:
+            batch_heatmaps (Tensor): Keypoint detection heatmaps in shape
+                (B, K, H, W)
+            batch_tags (Tensor): Tagging heatmaps in shape (B, C, H, W), where
+                :math:`C=L` if `tag_per_keypoint==False`, or
+                :math:`C=L*K` otherwise
+            input_sizes (List[Tuple[int, int]], optional): Manually set the
+                input size [w, h] of each sample for decoding. This is useful
+                when inferring a model on images with arbitrary sizes. If not
+                given, the value `self.input_size` set at initialization will
+                be used for all samples. Defaults to ``None``
+
+        Returns:
+            tuple:
+            - batch_keypoints (List[np.ndarray]): Decoded keypoint coordinates
+                of the batch, each is in shape (N, K, D)
+            - batch_scores (List[np.ndarray]): Decoded keypoint scores of the
+                batch, each is in shape (N, K). It usually represents the
+                confidence of the keypoint prediction
+        """
+        B, K, H, W = batch_heatmaps.shape
+        assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
+            f'Unmatched shapes of heatmap ({batch_heatmaps.shape}) and '
+            f'tagging map ({batch_tags.shape})')
+
+        if not self.tag_per_keypoint:
+            batch_tags = batch_tags.tile((1, K, 1, 1))
+
+        # Heatmap NMS
+        batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
+                                           self.decode_nms_kernel)
+
+        # Get top-k in each heatmap and convert to numpy
+        batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
+            self._get_batch_topk(
+                batch_heatmaps, batch_tags, k=self.decode_topk))
+
+        # Group keypoint candidates into groups (instances)
+        batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
+                                             batch_topk_locs)
+
+        batch_keypoints, batch_keypoint_scores = map(list, zip(*batch_groups))
+
+        # Convert to numpy
+        batch_heatmaps_np = to_numpy(batch_heatmaps)
+        batch_tags_np = to_numpy(batch_tags)
+
+        # Refine the keypoint prediction
+        for i, (keypoints, scores, heatmaps, tags) in enumerate(
+                zip(batch_keypoints, batch_keypoint_scores, batch_heatmaps_np,
+                    batch_tags_np)):
+
+            # identify and fill missing keypoints
+            keypoints, scores = self._fill_missing_keypoints(
+                keypoints, scores, heatmaps, tags)
+
+            # refine keypoint coordinates according to heatmap distribution
+            if self.use_udp:
+                keypoints = refine_keypoints_dark_udp(
+                    keypoints,
+                    heatmaps,
+                    blur_kernel_size=self.decode_gaussian_kernel)
+            else:
+                keypoints = refine_keypoints(keypoints, heatmaps)
+
+            batch_keypoints[i] = keypoints
+            batch_keypoint_scores[i] = scores
+
+        # restore keypoint scale
+        if input_sizes is None:
+            input_sizes = [self.input_size] * B
+        else:
+            assert len(input_sizes) == B
+
+        heatmap_size = (W, H)
+
+        batch_keypoints = [
+            kpts * self._get_scale_factor(input_size, heatmap_size)
+            for kpts, input_size in zip(batch_keypoints, input_sizes)
+        ]
+
+        return batch_keypoints, batch_keypoint_scores
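An end-to-end usage sketch of the new codec (not part of the patch; the all-zero tags are a dummy stand-in for network output, so every candidate falls into a single group):

```python
import numpy as np
import torch
from mmpose.codecs import AssociativeEmbedding

codec = AssociativeEmbedding(
    input_size=(256, 256),
    heatmap_size=(64, 64),
    decode_keypoint_order=list(range(17)))

# encode ground truth of one instance with 17 keypoints
keypoints = np.random.rand(1, 17, 2) * 256
heatmaps, keypoint_indices, keypoint_weights = codec.encode(keypoints)

# decode from (dummy) network outputs
batch_heatmaps = torch.from_numpy(heatmaps[None])
batch_tags = torch.zeros_like(batch_heatmaps)  # K tag channels, all zero
batch_keypoints, batch_scores = codec.batch_decode(batch_heatmaps, batch_tags)
print(batch_keypoints[0].shape)  # e.g. (1, 17, 2)
```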
It usually represents the + confidience of the keypoint prediction + """ + B, K, H, W = batch_heatmaps.shape + assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), ( + f'Unmatched shapes of heatmap ({batch_heatmaps.shape}) and ' + f'tagging map ({batch_tags.shape})') + + if not self.tag_per_keypoint: + batch_tags = batch_tags.tile((1, K, 1, 1)) + + # Heatmap NMS + batch_heatmaps = batch_heatmap_nms(batch_heatmaps, + self.decode_nms_kernel) + + # Get top-k in each heatmap and and convert to numpy + batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy( + self._get_batch_topk( + batch_heatmaps, batch_tags, k=self.decode_topk)) + + # Group keypoint candidates into groups (instances) + batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags, + batch_topk_locs) + + batch_keypoints, batch_keypoint_scores = map(list, zip(*batch_groups)) + + # Convert to numpy + batch_heatmaps_np = to_numpy(batch_heatmaps) + batch_tags_np = to_numpy(batch_tags) + + # Refine the keypoint prediction + for i, (keypoints, scores, heatmaps, tags) in enumerate( + zip(batch_keypoints, batch_keypoint_scores, batch_heatmaps_np, + batch_tags_np)): + + # identify missing keypoints + keypoints, scores = self._fill_missing_keypoints( + keypoints, scores, heatmaps, tags) + + # refine keypoint coordinates according to heatmap distribution + if self.use_udp: + keypoints = refine_keypoints_dark_udp( + keypoints, + heatmaps, + blur_kernel_size=self.decode_gaussian_kernel) + else: + keypoints = refine_keypoints(keypoints, heatmaps) + + batch_keypoints[i] = keypoints + batch_keypoint_scores[i] = scores + + # restore keypoint scale + if input_sizes is None: + input_sizes = [self.input_size] * B + else: + assert len(input_sizes) == B + + heatmap_size = (W, H) + + batch_keypoints = [ + kpts * self._get_scale_factor(input_size, heatmap_size) + for kpts, input_size in zip(batch_keypoints, input_sizes) + ] + + return batch_keypoints, batch_keypoint_scores diff --git a/mmpose/codecs/msra_heatmap.py b/mmpose/codecs/msra_heatmap.py index 5796e43c8c..880311a6ab 100644 --- a/mmpose/codecs/msra_heatmap.py +++ b/mmpose/codecs/msra_heatmap.py @@ -5,9 +5,10 @@ from mmpose.registry import KEYPOINT_CODECS from .base import BaseKeypointCodec -from .utils import (gaussian_blur, generate_gaussian_heatmaps, - get_heatmap_maximum) -from .utils.gaussian_heatmap import generate_unbiased_gaussian_heatmaps +from .utils.gaussian_heatmap import (generate_gaussian_heatmaps, + generate_unbiased_gaussian_heatmaps) +from .utils.post_processing import get_heatmap_maximum +from .utils.refinement import refine_keypoints, refine_keypoints_dark @KEYPOINT_CODECS.register_module() @@ -126,26 +127,16 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: keypoints, scores = get_heatmap_maximum(heatmaps) + # Unsqueeze the instance dimension for single-instance results + keypoints, scores = keypoints[None], scores[None] + if self.unbiased: # Alleviate biased coordinate - # Apply Gaussian distribution modulation. - heatmaps = gaussian_blur(heatmaps, kernel=self.blur_kernel_size) - heatmaps = np.log(np.maximum(heatmaps, 1e-10)) - for k in range(K): - keypoints[k] = self._taylor_decode( - heatmap=heatmaps[k], keypoint=keypoints[k]) + keypoints = refine_keypoints_dark( + keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size) + else: - # Add +/-0.25 shift to the predicted locations for higher acc. 
diff --git a/mmpose/codecs/udp_heatmap.py b/mmpose/codecs/udp_heatmap.py
index 1491deb878..9d48b39f9f 100644
--- a/mmpose/codecs/udp_heatmap.py
+++ b/mmpose/codecs/udp_heatmap.py
@@ -7,7 +7,7 @@
 from mmpose.registry import KEYPOINT_CODECS
 from .base import BaseKeypointCodec
 from .utils import (generate_offset_heatmap, generate_udp_gaussian_heatmaps,
-                    get_heatmap_maximum)
+                    get_heatmap_maximum, refine_keypoints_dark_udp)
 
 
 @KEYPOINT_CODECS.register_module()
@@ -139,8 +139,13 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
 
         if self.heatmap_type == 'gaussian':
             keypoints, scores = get_heatmap_maximum(heatmaps)
-            keypoints = self._postprocess_dark_udp(heatmaps, keypoints,
-                                                   self.blur_kernel_size)
+            # unsqueeze the instance dimension for single-instance results
+            keypoints = keypoints[None]
+            scores = scores[None]
+
+            keypoints = refine_keypoints_dark_udp(
+                keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)
+
         elif self.heatmap_type == 'combined':
             _K, H, W = heatmaps.shape
             K = _K // 3
@@ -163,61 +168,11 @@ def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
             index += W * H * np.arange(0, K)
             index = index.astype(int)
             keypoints += np.stack((x_offset[index], y_offset[index]), axis=-1)
+            # unsqueeze the instance dimension for single-instance results
+            keypoints = keypoints[None].astype(np.float32)
+            scores = scores[None]
 
-        # Unsqueeze the instance dimension for single-instance results
         W, H = self.heatmap_size
-        keypoints = keypoints[None] / [W - 1, H - 1] * self.input_size
-        scores = scores[None]
+        keypoints = keypoints / [W - 1, H - 1] * self.input_size
 
         return keypoints, scores
-
-    @staticmethod
-    def _postprocess_dark_udp(heatmaps: np.ndarray, keypoints: np.ndarray,
-                              kernel_size: int) -> np.ndarray:
-        """Distribution aware post-processing for UDP.
-
-        Args:
-            heatmaps (np.ndarray): Heatmaps in shape (K, H, W)
-            keypoints (np.ndarray): Keypoint coordinates in shape (K, D)
-            kernel_size (int): The Gaussian blur kernel size of the heatmap
-                modulation
-
-        Returns:
-            np.ndarray: Post-processed keypoint coordinates
-        """
-        K, H, W = heatmaps.shape
-
-        for k in range(K):
-            cv2.GaussianBlur(heatmaps[k], (kernel_size, kernel_size), 0,
-                             heatmaps[k])
-
-        np.clip(heatmaps, 0.001, 50., heatmaps)
-        np.log(heatmaps, heatmaps)
-        heatmaps_pad = np.pad(
-            heatmaps, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
-
-        index = keypoints[..., 0] + 1 + (keypoints[..., 1] + 1) * (W + 2)
-        index += (W + 2) * (H + 2) * np.arange(0, K)
-        index = index.astype(int).reshape(-1, 1)
-        i_ = heatmaps_pad[index]
-        ix1 = heatmaps_pad[index + 1]
-        iy1 = heatmaps_pad[index + W + 2]
-        ix1y1 = heatmaps_pad[index + W + 3]
-        ix1_y1_ = heatmaps_pad[index - W - 3]
-        ix1_ = heatmaps_pad[index - 1]
-        iy1_ = heatmaps_pad[index - 2 - W]
-
-        dx = 0.5 * (ix1 - ix1_)
-        dy = 0.5 * (iy1 - iy1_)
-        derivative = np.concatenate([dx, dy], axis=1)
-        derivative = derivative.reshape(K, 2, 1)
-
-        dxx = ix1 - 2 * i_ + ix1_
-        dyy = iy1 - 2 * i_ + iy1_
-        dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
-        hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
-        hessian = hessian.reshape(K, 2, 2)
-        hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
-        keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze()
-
-        return keypoints
diff --git a/mmpose/codecs/utils/__init__.py b/mmpose/codecs/utils/__init__.py
index 290cba0873..b6acf34cc9 100644
--- a/mmpose/codecs/utils/__init__.py
+++ b/mmpose/codecs/utils/__init__.py
@@ -3,11 +3,15 @@
                                generate_udp_gaussian_heatmaps,
                                generate_unbiased_gaussian_heatmaps)
 from .offset_heatmap import generate_offset_heatmap
-from .post_processing import (gaussian_blur, get_heatmap_maximum,
-                              get_simcc_maximum)
+from .post_processing import (batch_heatmap_nms, gaussian_blur,
+                              get_heatmap_maximum, get_simcc_maximum)
+from .refinement import (refine_keypoints, refine_keypoints_dark,
+                         refine_keypoints_dark_udp)
 
 __all__ = [
     'generate_gaussian_heatmaps', 'generate_udp_gaussian_heatmaps',
     'generate_unbiased_gaussian_heatmaps', 'gaussian_blur',
-    'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap'
+    'get_heatmap_maximum', 'get_simcc_maximum', 'generate_offset_heatmap',
+    'batch_heatmap_nms', 'refine_keypoints', 'refine_keypoints_dark',
+    'refine_keypoints_dark_udp'
 ]
diff --git a/mmpose/codecs/utils/post_processing.py b/mmpose/codecs/utils/post_processing.py
index 3f60a7e060..e6c9b71de3 100644
--- a/mmpose/codecs/utils/post_processing.py
+++ b/mmpose/codecs/utils/post_processing.py
@@ -3,6 +3,9 @@
 
 import cv2
 import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import Tensor
 
 
 def get_simcc_maximum(simcc_x: np.ndarray,
@@ -136,3 +139,28 @@ def gaussian_blur(heatmaps: np.ndarray, kernel: int = 11) -> np.ndarray:
         heatmaps[k] = dr[border:-border, border:-border].copy()
         heatmaps[k] *= origin_max / np.max(heatmaps[k])
     return heatmaps
+
+
+def batch_heatmap_nms(batch_heatmaps: Tensor, kernel_size: int = 5):
+    """Apply NMS on a batch of heatmaps.
+
+    Args:
+        batch_heatmaps (Tensor): batch heatmaps in shape (B, K, H, W)
+        kernel_size (int): The kernel size of the NMS which should be
+            an odd integer. Defaults to 5
+
+    Returns:
+        Tensor: The batch heatmaps after NMS.
+    """
+
+    assert isinstance(kernel_size, int) and kernel_size % 2 == 1, \
+        f'The kernel_size should be an odd integer, got {kernel_size}'
+
+    padding = (kernel_size - 1) // 2
+
+    maximum = F.max_pool2d(
+        batch_heatmaps, kernel_size, stride=1, padding=padding)
+    maximum_indicator = torch.eq(batch_heatmaps, maximum)
+    batch_heatmaps = batch_heatmaps * maximum_indicator.float()
+
+    return batch_heatmaps
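A small sketch (not part of the patch) showing the effect of the max-pooling NMS: only responses equal to their local maximum survive:

```python
import torch
from mmpose.codecs.utils import batch_heatmap_nms

heatmaps = torch.zeros(1, 1, 5, 5)
heatmaps[0, 0, 2, 2] = 1.0
heatmaps[0, 0, 2, 3] = 0.6  # suppressed: not a local maximum
out = batch_heatmap_nms(heatmaps, kernel_size=3)
print(out[0, 0, 2, 2].item(), out[0, 0, 2, 3].item())  # 1.0 0.0
```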
+ """ + + assert isinstance(kernel_size, int) and kernel_size % 2 == 1, \ + f'The kernel_size should be an odd integer, got {kernel_size}' + + padding = (kernel_size - 1) // 2 + + maximum = F.max_pool2d( + batch_heatmaps, kernel_size, stride=1, padding=padding) + maximum_indicator = torch.eq(batch_heatmaps, maximum) + batch_heatmaps = batch_heatmaps * maximum_indicator.float() + + return batch_heatmaps diff --git a/mmpose/codecs/utils/refinement.py b/mmpose/codecs/utils/refinement.py new file mode 100644 index 0000000000..53e9bdbf3a --- /dev/null +++ b/mmpose/codecs/utils/refinement.py @@ -0,0 +1,165 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from itertools import product + +import numpy as np + +from .post_processing import gaussian_blur + + +def refine_keypoints(keypoints: np.ndarray, + heatmaps: np.ndarray) -> np.ndarray: + """Refine keypoint predictions by moving from the maximum towards the + second maximum by 0.25 pixel. The operation is in-place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + """ + N, K = keypoints.shape[:2] + H, W = heatmaps.shape[1:] + + for n, k in product(range(N), range(K)): + x, y = keypoints[n, k, :2].astype(int) + + if 1 < x < W - 1: + dx = heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1] + else: + dx = 0. + + if 1 < y < H - 1: + dy = heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x] + else: + dy = 0. + + keypoints[n] += np.sign([dx, dy], dtype=np.float32) * 0.25 + + return keypoints + + +def refine_keypoints_dark(keypoints: np.ndarray, heatmaps: np.ndarray, + blur_kernel_size: int) -> np.ndarray: + """Refine keypoint predictions using distribution aware coordinate + decoding. See `Dark Pose`_ for details. The operation is in-place. + + Note: + + - instance number: N + - keypoint number: K + - keypoint dimension: D + - heatmap size: [W, H] + + Args: + keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D) + heatmaps (np.ndarray): The heatmaps in shape (K, H, W) + blur_kernel_size (int): The Gaussian blur kernel size of the heatmap + modulation + + Returns: + np.ndarray: Refine keypoint coordinates in shape (N, K, D) + + .. 
+
+    .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
+    """
+    N, K = keypoints.shape[:2]
+    H, W = heatmaps.shape[1:]
+
+    # modulate heatmaps
+    heatmaps = gaussian_blur(heatmaps, blur_kernel_size)
+    np.maximum(heatmaps, 1e-10, heatmaps)
+    np.log(heatmaps, heatmaps)
+
+    for n, k in product(range(N), range(K)):
+        x, y = keypoints[n, k, :2].astype(int)
+        if 1 < x < W - 2 and 1 < y < H - 2:
+            dx = 0.5 * (heatmaps[k, y, x + 1] - heatmaps[k, y, x - 1])
+            dy = 0.5 * (heatmaps[k, y + 1, x] - heatmaps[k, y - 1, x])
+
+            dxx = 0.25 * (
+                heatmaps[k, y, x + 2] - 2 * heatmaps[k, y, x] +
+                heatmaps[k, y, x - 2])
+            dxy = 0.25 * (
+                heatmaps[k, y + 1, x + 1] - heatmaps[k, y - 1, x + 1] -
+                heatmaps[k, y + 1, x - 1] + heatmaps[k, y - 1, x - 1])
+            dyy = 0.25 * (
+                heatmaps[k, y + 2, x] - 2 * heatmaps[k, y, x] +
+                heatmaps[k, y - 2, x])
+            derivative = np.array([[dx], [dy]])
+            hessian = np.array([[dxx, dxy], [dxy, dyy]])
+            if dxx * dyy - dxy**2 != 0:
+                hessianinv = np.linalg.inv(hessian)
+                offset = -hessianinv @ derivative
+                offset = np.squeeze(np.array(offset.T), axis=0)
+                keypoints[n, k, :2] += offset
+    return keypoints
+
+
+def refine_keypoints_dark_udp(keypoints: np.ndarray, heatmaps: np.ndarray,
+                              blur_kernel_size: int) -> np.ndarray:
+    """Refine keypoint predictions using distribution aware coordinate
+    decoding for UDP. See `UDP`_ for details. The operation is in-place.
+
+    Note:
+
+        - instance number: N
+        - keypoint number: K
+        - keypoint dimension: D
+        - heatmap size: [W, H]
+
+    Args:
+        keypoints (np.ndarray): The keypoint coordinates in shape (N, K, D)
+        heatmaps (np.ndarray): The heatmaps in shape (K, H, W)
+        blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
+            modulation
+
+    Returns:
+        np.ndarray: Refined keypoint coordinates in shape (N, K, D)
+
+    .. _`UDP`: https://arxiv.org/abs/1911.07524
+    """
+    N, K = keypoints.shape[:2]
+    H, W = heatmaps.shape[1:]
+
+    # modulate heatmaps
+    heatmaps = gaussian_blur(heatmaps, blur_kernel_size)
+    np.clip(heatmaps, 1e-3, 50., heatmaps)
+    np.log(heatmaps, heatmaps)
+
+    heatmaps_pad = np.pad(
+        heatmaps, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten()
+
+    for n in range(N):
+        index = keypoints[n, :, 0] + 1 + (keypoints[n, :, 1] + 1) * (W + 2)
+        index += (W + 2) * (H + 2) * np.arange(0, K)
+        index = index.astype(int).reshape(-1, 1)
+        i_ = heatmaps_pad[index]
+        ix1 = heatmaps_pad[index + 1]
+        iy1 = heatmaps_pad[index + W + 2]
+        ix1y1 = heatmaps_pad[index + W + 3]
+        ix1_y1_ = heatmaps_pad[index - W - 3]
+        ix1_ = heatmaps_pad[index - 1]
+        iy1_ = heatmaps_pad[index - 2 - W]
+
+        dx = 0.5 * (ix1 - ix1_)
+        dy = 0.5 * (iy1 - iy1_)
+        derivative = np.concatenate([dx, dy], axis=1)
+        derivative = derivative.reshape(K, 2, 1)
+
+        dxx = ix1 - 2 * i_ + ix1_
+        dyy = iy1 - 2 * i_ + iy1_
+        dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_)
+        hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1)
+        hessian = hessian.reshape(K, 2, 2)
+        hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2))
+        keypoints[n] -= np.einsum('imn,ink->imk', hessian,
+                                  derivative).squeeze()
+
+    return keypoints
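A tiny sketch (not part of the patch) of the quarter-pixel shift performed by `refine_keypoints`: the keypoint moves 0.25 px towards the higher neighboring response:

```python
import numpy as np
from mmpose.codecs.utils import refine_keypoints

heatmaps = np.zeros((1, 5, 5), dtype=np.float32)
heatmaps[0, 2, 2] = 1.0
heatmaps[0, 2, 3] = 0.5  # second-highest response to the right
keypoints = np.array([[[2., 2.]]])  # (N=1, K=1, D=2)
print(refine_keypoints(keypoints, heatmaps))  # [[[2.25 2.  ]]]
```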
diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py
index c7d9063dd5..cddf43f688 100644
--- a/mmpose/datasets/transforms/bottomup_transforms.py
+++ b/mmpose/datasets/transforms/bottomup_transforms.py
@@ -323,21 +323,21 @@ class BottomupResize(BaseTransform):
     Added Keys:
 
     - input_size
-    - input_scales
-    - resize_mode
+
     Args:
         input_size (Tuple[int, int]): The input size of the model in [w, h].
             Note that the actually size of the resized image will be affected
-            by ``resize_mode`` and ``base_size``, thus may not exactly equals
+            by ``resize_mode`` and ``size_factor``, thus may not exactly equal
             to the ``input_size``
-        input_scales (List[float], optional): The scales to build an image
-            pyramid, where the image size at i-th level will be
-            :math:`input_size * input_scales[i]`. If not given, the result
-            will be a single image resized to ``input_size``. Defaults to
-            ``None``
-        base_size (int): The actual input size will be ceiled to
-            a multiple of the `base` value at both sides. Defaults to 8
+        aux_scales (List[float], optional): The auxiliary input scales for
+            multi-scale testing. If given, the input image will be resized
+            to different scales to build an image pyramid, and heatmaps from
+            all scales will be aggregated to make the final prediction.
+            Defaults to ``None``
+        size_factor (int): The actual input size will be ceiled to
+            a multiple of the `size_factor` value at both sides.
+            Defaults to 8
         resize_mode (str): The method to resize the image to the input size.
             Options are:
@@ -357,16 +357,16 @@ class BottomupResize(BaseTransform):
 
     def __init__(self,
                  input_size: Tuple[int, int],
-                 input_scales: Optional[List[float]] = None,
-                 base_size: int = 8,
+                 aux_scales: Optional[List[float]] = None,
+                 size_factor: int = 8,
                  resize_mode: str = 'fit',
                  use_udp: bool = False):
         super().__init__()
 
         self.input_size = input_size
-        self.input_scales = input_scales
+        self.aux_scales = aux_scales
         self.resize_mode = resize_mode
-        self.base_size = base_size
+        self.size_factor = size_factor
         self.use_udp = use_udp
 
     @staticmethod
@@ -374,7 +374,7 @@ def _ceil_to_multiple(size: Tuple[int, int], base: int):
         """Ceil the given size (tuple of [w, h]) to a multiple of the base."""
         return tuple(int(np.ceil(s / base) * base) for s in size)
 
-    def _get_target_size(self, img_size: Tuple[int, int],
+    def _get_actual_size(self, img_size: Tuple[int, int],
                          input_size: Tuple[int, int]) -> Tuple:
         """Calculate the actual input size and the size of the resized image.
 
@@ -384,44 +384,44 @@ def _get_target_size(self, img_size: Tuple[int, int],
 
         Returns:
             tuple:
-            - target_input_size (Tuple[int, int]): The target size to generate
+            - actual_input_size (Tuple[int, int]): The target size to generate
                 the model input which will contain the resized image
-            - target_img_size (Tuple[int, int]): The target size to resize the
+            - actual_img_size (Tuple[int, int]): The target size to resize the
                 image
         """
         img_w, img_h = img_size
         ratio = img_w / img_h
 
         if self.resize_mode == 'fit':
-            target_input_size = self._ceil_to_multiple(input_size,
-                                                       self.base_size)
-            if target_input_size != input_size:
+            actual_input_size = self._ceil_to_multiple(input_size,
+                                                       self.size_factor)
+            if actual_input_size != input_size:
                 raise ValueError(
                     'When ``resize_mode==\'fit\', the input size (height and'
-                    ' width) should be mulitples of the base_size('
-                    f'{self.base_size}) at all scales. Got invalid input size'
-                    f' {input_size}.')
+                    ' width) should be multiples of the size_factor('
+                    f'{self.size_factor}) at all scales. Got invalid input '
+                    f'size {input_size}.')
 
-            tgt_w, tgt_h = target_input_size
+            tgt_w, tgt_h = actual_input_size
             rsz_w = min(tgt_w, tgt_h * ratio)
             rsz_h = min(tgt_h, tgt_w / ratio)
-            target_img_size = (rsz_w, rsz_h)
+            actual_img_size = (rsz_w, rsz_h)
         elif self.resize_mode == 'expand':
-            _target_input_size = self._ceil_to_multiple(
-                input_size, self.base_size)
-            tgt_w, tgt_h = _target_input_size
+            _actual_input_size = self._ceil_to_multiple(
+                input_size, self.size_factor)
+            tgt_w, tgt_h = _actual_input_size
             rsz_w = max(tgt_w, tgt_h * ratio)
             rsz_h = max(tgt_h, tgt_w / ratio)
-            target_img_size = (rsz_w, rsz_h)
-            target_input_size = self._ceil_to_multiple(target_img_size,
-                                                       self.base_size)
+            actual_img_size = (rsz_w, rsz_h)
+            actual_input_size = self._ceil_to_multiple(actual_img_size,
+                                                       self.size_factor)
         else:
             raise ValueError(f'Invalid resize mode {self.resize_mode}')
 
-        return target_input_size, target_img_size
+        return actual_input_size, actual_img_size
 
     def transform(self, results: Dict) -> Optional[dict]:
         """The transform function of :class:`BottomupResize` to perform
@@ -441,17 +441,19 @@ def transform(self, results: Dict) -> Optional[dict]:
 
         img_h, img_w = results['img_shape']
         w, h = self.input_size
 
-        if self.input_scales:
-            input_sizes = [(int(w * s), int(h * s)) for s in self.input_scales]
-        else:
-            input_sizes = [(w, h)]
+        input_sizes = [(w, h)]
+        if self.aux_scales:
+            input_sizes += [(int(w * s), int(h * s))
+                            for s in self.aux_scales]
 
         imgs = []
         warp_mats = []
+        actual_input_sizes = []
+        actual_img_sizes = []
 
         for _w, _h in input_sizes:
 
-            target_input_size, target_img_size = self._get_target_size(
+            actual_input_size, actual_img_size = self._get_actual_size(
                 img_size=(img_w, img_h), input_size=(_w, _h))
 
             if self.use_udp:
@@ -462,7 +464,7 @@ def transform(self, results: Dict) -> Optional[dict]:
                     center=center,
                     scale=scale,
                     rot=0,
-                    output_size=target_img_size)
+                    output_size=actual_img_size)
             else:
                 center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
                 scale = np.array([img_w, img_h], dtype=np.float32)
@@ -470,19 +472,24 @@ def transform(self, results: Dict) -> Optional[dict]:
                     center=center,
                     scale=scale,
                     rot=0,
-                    output_size=target_img_size)
+                    output_size=actual_img_size)
 
             _img = cv2.warpAffine(
-                img, warp_mat, target_input_size, flags=cv2.INTER_LINEAR)
+                img, warp_mat, actual_input_size, flags=cv2.INTER_LINEAR)
 
             imgs.append(_img)
             warp_mats.append(warp_mat)
+            actual_input_sizes.append(actual_input_size)
+            actual_img_sizes.append(actual_img_size)
 
-        if self.input_scales:
+        if self.aux_scales:
             results['img'] = imgs
-            results['warp_mat'] = warp_mats
         else:
             results['img'] = imgs[0]
-            results['warp_mat'] = warp_mats[0]
+
+        # The size/transform information of the original image
+        results['warp_mat'] = warp_mats[0]
+        results['input_size'] = actual_input_sizes[0]
+        results['img_size'] = actual_img_sizes[0]
 
         return results
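A rough usage sketch of the renamed options (not part of the patch; it assumes `BottomupResize` is exported from `mmpose.datasets.transforms` and uses the result-dict keys shown in the diff):

```python
import numpy as np
from mmpose.datasets.transforms import BottomupResize

transform = BottomupResize(
    input_size=(512, 512), aux_scales=[2.0], size_factor=8,
    resize_mode='expand')
results = dict(
    img=np.zeros((480, 640, 3), dtype=np.uint8), img_shape=(480, 640))
results = transform(results)
# with aux_scales set, results['img'] is a list (an image pyramid), while
# 'warp_mat', 'input_size' and 'img_size' describe the primary scale
print(len(results['img']), results['input_size'])
```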
diff --git a/mmpose/datasets/transforms/common_transforms.py b/mmpose/datasets/transforms/common_transforms.py
index c3b928c85b..5ecec9715a 100644
--- a/mmpose/datasets/transforms/common_transforms.py
+++ b/mmpose/datasets/transforms/common_transforms.py
@@ -95,7 +95,7 @@ class RandomFlip(BaseTransform):
 
     - img
     - img_shape
-    - flip_pairs
+    - flip_indices
     - bbox (optional)
     - bbox_center (optional)
     - keypoints (optional)
diff --git a/mmpose/models/utils/tta.py b/mmpose/models/utils/tta.py
index c32c994122..e486d8a49a 100644
--- a/mmpose/models/utils/tta.py
+++ b/mmpose/models/utils/tta.py
@@ -18,7 +18,7 @@ def flip_heatmaps(heatmaps: Tensor,
         flip_mode (str): Specify the flipping mode. Options are:
 
             - ``'heatmap'``: horizontally flip the heatmaps and swap heatmaps
-              of symmetric keypoints according to ``flip_pairs``
+              of symmetric keypoints according to ``flip_indices``
            - ``'udp_combined'``: similar to ``'heatmap'`` mode but further
              flip the x_offset values
         shift_heatmap (bool): Shift the flipped heatmaps to align with the
diff --git a/mmpose/testing/_utils.py b/mmpose/testing/_utils.py
index 21e9a8451b..5be9387e76 100644
--- a/mmpose/testing/_utils.py
+++ b/mmpose/testing/_utils.py
@@ -15,7 +15,8 @@ def get_coco_sample(
         num_instances=1,
         with_bbox_cs=True,
         with_img_mask=False,
-        random_keypoints_visible=False):
+        random_keypoints_visible=False,
+        non_occlusion=False):
     """Create a dummy data sample in COCO style."""
     rng = np.random.RandomState(0)
     h, w = img_shape
@@ -24,7 +25,13 @@ def get_coco_sample(
     else:
         img = np.full((h, w, 3), img_fill, dtype=np.uint8)
 
-    bbox = _rand_bboxes(rng, num_instances, w, h)
+    if non_occlusion:
+        # generate non-overlapping bboxes by sampling each instance in its
+        # own vertical strip of the image
+        bbox = _rand_bboxes(rng, num_instances, w / num_instances, h)
+        for i in range(num_instances):
+            bbox[i, 0::2] += w / num_instances * i
+    else:
+        bbox = _rand_bboxes(rng, num_instances, w, h)
+
     keypoints = _rand_keypoints(rng, bbox, 17)
     if random_keypoints_visible:
         keypoints_visible = np.random.randint(0, 2, (num_instances,
diff --git a/tests/test_codecs/test_associative_embedding.py b/tests/test_codecs/test_associative_embedding.py
new file mode 100644
index 0000000000..6f5266d6d6
--- /dev/null
+++ b/tests/test_codecs/test_associative_embedding.py
@@ -0,0 +1,303 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from unittest import TestCase
+
+import numpy as np
+import torch
+from munkres import Munkres
+
+from mmpose.codecs import AssociativeEmbedding
+from mmpose.registry import KEYPOINT_CODECS
+from mmpose.testing import get_coco_sample
+
+
+class TestAssociativeEmbedding(TestCase):
+
+    def setUp(self) -> None:
+        self.decode_keypoint_order = [
+            0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
+        ]
+
+    def test_build(self):
+        cfg = dict(
+            type='AssociativeEmbedding',
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=False,
+            decode_keypoint_order=self.decode_keypoint_order,
+        )
+        codec = KEYPOINT_CODECS.build(cfg)
+        self.assertIsInstance(codec, AssociativeEmbedding)
+
+    def test_encode(self):
+        data = get_coco_sample(img_shape=(256, 256), num_instances=1)
+
+        # w/o UDP
+        codec = AssociativeEmbedding(
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=False,
+            decode_keypoint_order=self.decode_keypoint_order)
+
+        heatmaps, keypoint_indices, keypoint_weights = codec.encode(
+            data['keypoints'], data['keypoints_visible'])
+
+        self.assertEqual(heatmaps.shape, (17, 64, 64))
+        self.assertEqual(keypoint_indices.shape, (1, 17, 2))
+        self.assertEqual(keypoint_weights.shape, (1, 17))
+
+        for k in range(heatmaps.shape[0]):
+            index_expected = np.argmax(heatmaps[k])
+            index_encoded = keypoint_indices[0, k, 0]
+            self.assertEqual(index_expected, index_encoded)
+
+        # w/ UDP
+        codec = AssociativeEmbedding(
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=True,
+            decode_keypoint_order=self.decode_keypoint_order)
+
+        heatmaps, keypoint_indices, keypoint_weights = codec.encode(
+            data['keypoints'], data['keypoints_visible'])
+
+        self.assertEqual(heatmaps.shape, (17, 64, 64))
+        self.assertEqual(keypoint_indices.shape, (1, 17, 2))
+        self.assertEqual(keypoint_weights.shape, (1, 17))
+
+        for k in range(heatmaps.shape[0]):
+            index_expected = np.argmax(heatmaps[k])
+            index_encoded = keypoint_indices[0, k, 0]
+            self.assertEqual(index_expected, index_encoded)
+
+    def _get_tags(self, heatmaps, keypoint_indices, tag_per_keypoint: bool):
+
+        K, H, W = heatmaps.shape
+        N = keypoint_indices.shape[0]
+
+        if tag_per_keypoint:
+            tags = np.zeros((K, H, W), dtype=np.float32)
+        else:
+            tags = np.zeros((1, H, W), dtype=np.float32)
+
+        for n, k in product(range(N), range(K)):
+            y, x = np.unravel_index(keypoint_indices[n, k, 0], (H, W))
+            if tag_per_keypoint:
+                tags[k, y, x] = n
+            else:
+                tags[0, y, x] = n
+
+        return tags
+
+    def _sort_preds(self, keypoints_pred, scores_pred, keypoints_gt):
+        """Sort multi-instance predictions to best match the ground-truth.
+
+        Args:
+            keypoints_pred (np.ndarray): predictions in shape (N, K, D)
+            scores_pred (np.ndarray): prediction scores in shape (N, K)
+            keypoints_gt (np.ndarray): ground-truth in shape (N, K, D)
+
+        Returns:
+            np.ndarray: Sorted predictions
+        """
+        assert keypoints_gt.shape == keypoints_pred.shape
+        costs = np.linalg.norm(
+            keypoints_gt[None] - keypoints_pred[:, None], ord=2,
+            axis=3).mean(axis=2)
+        # costs[i, j] is the distance between the i-th prediction and the
+        # j-th ground-truth instance
+        match = Munkres().compute(costs)
+        keypoints_pred_sorted = np.zeros_like(keypoints_pred)
+        scores_pred_sorted = np.zeros_like(scores_pred)
+        for i, j in match:
+            keypoints_pred_sorted[j] = keypoints_pred[i]
+            scores_pred_sorted[j] = scores_pred[i]
+
+        return keypoints_pred_sorted, scores_pred_sorted
+
+    def test_decode(self):
+        data = get_coco_sample(
+            img_shape=(256, 256), num_instances=2, non_occlusion=True)
+
+        # w/o UDP, tag_per_keypoint==True
+        codec = AssociativeEmbedding(
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=False,
+            decode_keypoint_order=self.decode_keypoint_order,
+            tag_per_keypoint=True)
+
+        heatmaps, keypoint_indices, _ = codec.encode(
+            data['keypoints'], data['keypoints_visible'])
+
+        tags = self._get_tags(
+            heatmaps, keypoint_indices, tag_per_keypoint=True)
+
+        # to Tensor
+        batch_heatmaps = torch.from_numpy(heatmaps[None])
+        batch_tags = torch.from_numpy(tags[None])
+
+        batch_keypoints, batch_keypoint_scores = codec.batch_decode(
+            batch_heatmaps, batch_tags)
+
+        self.assertIsInstance(batch_keypoints, list)
+        self.assertIsInstance(batch_keypoint_scores, list)
+        self.assertEqual(len(batch_keypoints), 1)
+        self.assertEqual(len(batch_keypoint_scores), 1)
+
+        keypoints, scores = self._sort_preds(batch_keypoints[0],
+                                             batch_keypoint_scores[0],
+                                             data['keypoints'])
+
+        self.assertIsInstance(keypoints, np.ndarray)
+        self.assertIsInstance(scores, np.ndarray)
+        self.assertEqual(keypoints.shape, (2, 17, 2))
+        self.assertEqual(scores.shape, (2, 17))
+
+        self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
+
+        # w/o UDP, tag_per_keypoint==False
+        codec = AssociativeEmbedding(
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=False,
+            decode_keypoint_order=self.decode_keypoint_order,
+            tag_per_keypoint=False)
+
+        heatmaps, keypoint_indices, _ = codec.encode(
+            data['keypoints'], data['keypoints_visible'])
+
+        tags = self._get_tags(
+            heatmaps, keypoint_indices, tag_per_keypoint=False)
+
+        # to Tensor
+        batch_heatmaps = torch.from_numpy(heatmaps[None])
+        batch_tags = torch.from_numpy(tags[None])
+
+        batch_keypoints, batch_keypoint_scores = codec.batch_decode(
+            batch_heatmaps, batch_tags)
+
+        self.assertIsInstance(batch_keypoints, list)
+        self.assertIsInstance(batch_keypoint_scores, list)
+        self.assertEqual(len(batch_keypoints), 1)
+        self.assertEqual(len(batch_keypoint_scores), 1)
+
+        keypoints, scores = self._sort_preds(batch_keypoints[0],
+                                             batch_keypoint_scores[0],
+                                             data['keypoints'])
+
+        self.assertIsInstance(keypoints, np.ndarray)
+        self.assertIsInstance(scores, np.ndarray)
+        self.assertEqual(keypoints.shape, (2, 17, 2))
+        self.assertEqual(scores.shape, (2, 17))
+
+        self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
+
+        # w/ UDP, tag_per_keypoint==True
+        codec = AssociativeEmbedding(
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=True,
+            decode_keypoint_order=self.decode_keypoint_order,
+            tag_per_keypoint=True)
+
+        heatmaps, keypoint_indices, _ = codec.encode(
+            data['keypoints'], data['keypoints_visible'])
+
+        tags = self._get_tags(
+            heatmaps, keypoint_indices, tag_per_keypoint=True)
+
+        # to Tensor
+        batch_heatmaps = torch.from_numpy(heatmaps[None])
+        batch_tags = torch.from_numpy(tags[None])
+
+        batch_keypoints, batch_keypoint_scores = codec.batch_decode(
+            batch_heatmaps, batch_tags)
+
+        self.assertIsInstance(batch_keypoints, list)
+        self.assertIsInstance(batch_keypoint_scores, list)
+        self.assertEqual(len(batch_keypoints), 1)
+        self.assertEqual(len(batch_keypoint_scores), 1)
+
+        keypoints, scores = self._sort_preds(batch_keypoints[0],
+                                             batch_keypoint_scores[0],
+                                             data['keypoints'])
+
+        self.assertIsInstance(keypoints, np.ndarray)
+        self.assertIsInstance(scores, np.ndarray)
+        self.assertEqual(keypoints.shape, (2, 17, 2))
+        self.assertEqual(scores.shape, (2, 17))
+
+        self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
+
+        # w/ UDP, tag_per_keypoint==False
+        codec = AssociativeEmbedding(
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=True,
+            decode_keypoint_order=self.decode_keypoint_order,
+            tag_per_keypoint=False)
+
+        heatmaps, keypoint_indices, _ = codec.encode(
+            data['keypoints'], data['keypoints_visible'])
+
+        tags = self._get_tags(
+            heatmaps, keypoint_indices, tag_per_keypoint=False)
+
+        # to Tensor
+        batch_heatmaps = torch.from_numpy(heatmaps[None])
+        batch_tags = torch.from_numpy(tags[None])
+
+        batch_keypoints, batch_keypoint_scores = codec.batch_decode(
+            batch_heatmaps, batch_tags)
+
+        self.assertIsInstance(batch_keypoints, list)
+        self.assertIsInstance(batch_keypoint_scores, list)
+        self.assertEqual(len(batch_keypoints), 1)
+        self.assertEqual(len(batch_keypoint_scores), 1)
+
+        keypoints, scores = self._sort_preds(batch_keypoints[0],
+                                             batch_keypoint_scores[0],
+                                             data['keypoints'])
+
+        self.assertIsInstance(keypoints, np.ndarray)
+        self.assertIsInstance(scores, np.ndarray)
+        self.assertEqual(keypoints.shape, (2, 17, 2))
+        self.assertEqual(scores.shape, (2, 17))
+
+        self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))
+
+        # Dynamic input sizes in decoder
+        codec = AssociativeEmbedding(
+            input_size=(256, 256),
+            heatmap_size=(64, 64),
+            use_udp=False,
+            decode_keypoint_order=self.decode_keypoint_order,
+            tag_per_keypoint=True)
+
+        heatmaps, keypoint_indices, _ = codec.encode(
+            data['keypoints'], data['keypoints_visible'])
+
+        tags = self._get_tags(
+            heatmaps, keypoint_indices, tag_per_keypoint=True)
+
+        # to Tensor
+        batch_heatmaps = torch.from_numpy(heatmaps[None])
+        batch_tags = torch.from_numpy(tags[None])
+
+        batch_keypoints, batch_keypoint_scores = codec.batch_decode(
+            batch_heatmaps, batch_tags, input_sizes=[(256, 256)])
+
+        self.assertIsInstance(batch_keypoints, list)
+        self.assertIsInstance(batch_keypoint_scores, list)
+        self.assertEqual(len(batch_keypoints), 1)
+        self.assertEqual(len(batch_keypoint_scores), 1)
+
+        keypoints, scores = self._sort_preds(batch_keypoints[0],
+                                             batch_keypoint_scores[0],
+                                             data['keypoints'])
+
+        self.assertIsInstance(keypoints, np.ndarray)
+        self.assertIsInstance(scores, np.ndarray)
+        self.assertEqual(keypoints.shape, (2, 17, 2))
+        self.assertEqual(scores.shape, (2, 17))
+
+        self.assertTrue(np.allclose(keypoints, data['keypoints'], atol=4.0))