Skip to content

Commit

Permalink
[Feature] Generate Heatmap (#336)
Browse files Browse the repository at this point in the history
* [Feature] Generate Heatmap

* Rename

* Rename

* Update

* Update

Co-authored-by: liyinshuo <[email protected]>
  • Loading branch information
Yshuo-Li and liyinshuo authored Jun 1, 2021
1 parent c6f0ce2 commit 6703c5d
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
CropLike, FixedCrop, ModCrop, PairedRandomCrop)
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
from .generate_coordinate_and_cell import GenerateCoordinateAndCell
from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
LoadImageFromFileList, LoadMask, LoadPairedImageFromFile,
RandomLoadResizeBg)
Expand All @@ -31,5 +31,5 @@
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike'
'CropLike', 'GenerateHeatmap'
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,80 @@
from .utils import make_coord


@PIPELINES.register_module()
class GenerateHeatmap:
"""Generate heatmap from keypoint.
Args:
keypoint (str): Key of keypoint in dict.
ori_size (int | Tuple[int]): Original image size of keypoint.
target_size (int | Tuple[int]): Target size of heatmap.
sigma (float): Sigma parameter of heatmap. Default: 1.0
"""

def __init__(self, keypoint, ori_size, target_size, sigma=1.0):
if isinstance(ori_size, int):
ori_size = (ori_size, ori_size)
else:
ori_size = ori_size[:2]
if isinstance(target_size, int):
target_size = (target_size, target_size)
else:
target_size = target_size[:2]
self.size_ratio = (target_size[0] / ori_size[0],
target_size[1] / ori_size[1])
self.keypoint = keypoint
self.sigma = sigma
self.target_size = target_size
self.ori_size = ori_size

def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation. Require keypoint.
Returns:
dict: A dict containing the processed data and information.
Add 'heatmap'.
"""
keypoint_list = [(keypoint[0] * self.size_ratio[0],
keypoint[1] * self.size_ratio[1])
for keypoint in results[self.keypoint]]
heatmap_list = [
self._generate_one_heatmap(keypoint) for keypoint in keypoint_list
]
results['heatmap'] = np.stack(heatmap_list, axis=2)
return results

def _generate_one_heatmap(self, keypoint):
"""Generate One Heatmap.
Args:
landmark (Tuple[float]): Location of a landmark.
results:
heatmap (np.ndarray): A heatmap of landmark.
"""
w, h = self.target_size

x_range = np.arange(start=0, stop=w, dtype=int)
y_range = np.arange(start=0, stop=h, dtype=int)
grid_x, grid_y = np.meshgrid(x_range, y_range)
dist2 = (grid_x - keypoint[0])**2 + (grid_y - keypoint[1])**2
exponent = dist2 / 2.0 / self.sigma / self.sigma
heatmap = np.exp(-exponent)
return heatmap

def __repr__(self):
return (f'{self.__class__.__name__}, '
f'keypoint={self.keypoint}, '
f'ori_size={self.ori_size}, '
f'target_size={self.target_size}, '
f'sigma={self.sigma}')


@PIPELINES.register_module()
class GenerateCoordinateAndCell:
"""Generate coordinate and cell.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import torch

from mmedit.datasets.pipelines import GenerateCoordinateAndCell
from mmedit.datasets.pipelines import (GenerateCoordinateAndCell,
GenerateHeatmap)


def test_generate_heatmap():
inputs = dict(landmark=[(1, 2), (3, 4)])
generate_heatmap = GenerateHeatmap('landmark', 4, 16)
results = generate_heatmap(inputs)
assert set(list(results.keys())) == set(['landmark', 'heatmap'])
assert results['heatmap'][:, :, 0].shape == (16, 16)
assert repr(generate_heatmap) == (
f'{generate_heatmap.__class__.__name__}, '
f'keypoint={generate_heatmap.keypoint}, '
f'ori_size={generate_heatmap.ori_size}, '
f'target_size={generate_heatmap.target_size}, '
f'sigma={generate_heatmap.sigma}')

generate_heatmap = GenerateHeatmap('landmark', (4, 5), (16, 17))
results = generate_heatmap(inputs)
assert set(list(results.keys())) == set(['landmark', 'heatmap'])
assert results['heatmap'][:, :, 0].shape == (17, 16)


def test_generate_coordinate_and_cell():
Expand Down

0 comments on commit 6703c5d

Please sign in to comment.