Skip to content

Commit

Permalink
[Feature] Generate Heatmap
Browse files Browse the repository at this point in the history
  • Loading branch information
liyinshuo committed May 29, 2021
1 parent 6238c69 commit 585ebfa
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
CropLike, FixedCrop, ModCrop, PairedRandomCrop)
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
from .generate_coordinate_and_cell import GenerateCoordinateAndCell
from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
LoadImageFromFileList, LoadMask, LoadPairedImageFromFile,
RandomLoadResizeBg)
Expand All @@ -31,5 +31,5 @@
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike'
'CropLike', 'GenerateHeatmap'
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,78 @@
from .utils import make_coord


@PIPELINES.register_module()
class GenerateHeatmap:
"""Generate heatmap from keypoint.
Args:
keypoint (str): Key of keypoint in dict.
ori_size (int | Tuple[int]): Original image size of keypoint.
target_size (int | Tuple[int]): Target size of heatmap.
sigma (float): Sigma parameter of heatmap. Default: 1.0
"""

def __init__(self, keypoint, ori_size, target_size, sigma=1.0):
if isinstance(ori_size, int):
ori_size = (ori_size, ori_size)
else:
ori_size = ori_size[:2]
if isinstance(target_size, int):
target_size = (target_size, target_size)
else:
target_size = target_size[:2]
self.rate = (target_size[0] / ori_size[0],
target_size[1] / ori_size[1])
self.keypoint = keypoint
self.sigma = sigma
self.target_size = target_size

def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation. Require keypoint.
Returns:
dict: A dict containing the processed data and information.
Add 'heatmap'.
"""
keypoint_list = [(keypoint[0] * self.rate[0],
keypoint[1] * self.rate[1])
for keypoint in results[self.keypoint]]
heatmap_list = [
self._generate_one_heatmap(keypoint) for keypoint in keypoint_list
]
results['heatmap'] = np.stack(heatmap_list, axis=2)
return results

def _generate_one_heatmap(self, keypoint):
"""Generate One Heatmap.
Args:
landmark (Tuple[float]): Location of a landmark.
results:
heatmap (np.ndarray): A heatmap of landmark.
"""
w, h = self.target_size

x_range = np.arange(start=0, stop=w, dtype=int)
y_range = np.arange(start=0, stop=h, dtype=int)
xx, yy = np.meshgrid(x_range, y_range)
d2 = (xx - keypoint[0])**2 + (yy - keypoint[1])**2
exponent = d2 / 2.0 / self.sigma / self.sigma
heatmap = np.exp(-exponent)
return heatmap

def __repr__(self):
return (f'{self.__class__.__name__}, '
f'keypoint={self.keypoint}, '
f'rate={self.rate}, target_size={self.target_size}, '
f'sigma={self.sigma}')


@PIPELINES.register_module()
class GenerateCoordinateAndCell:
"""Generate coordinate and cell.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import torch

from mmedit.datasets.pipelines import GenerateCoordinateAndCell
from mmedit.datasets.pipelines import (GenerateCoordinateAndCell,
GenerateHeatmap)


def test_generate_heatmap():
inputs = dict(landmark=[(1, 2), (3, 4)])
generate_heatmap = GenerateHeatmap('landmark', 4, 16)
results = generate_heatmap(inputs)
assert set(list(results.keys())) == set(['landmark', 'heatmap'])
assert results['heatmap'][:, :, 0].shape == (16, 16)
assert repr(generate_heatmap) == (
f'{generate_heatmap.__class__.__name__}, '
f'keypoint={generate_heatmap.keypoint}, '
f'rate={generate_heatmap.rate}, '
f'target_size={generate_heatmap.target_size}, '
f'sigma={generate_heatmap.sigma}')

generate_heatmap = GenerateHeatmap('landmark', (4, 5), (16, 17))
results = generate_heatmap(inputs)
assert set(list(results.keys())) == set(['landmark', 'heatmap'])
assert results['heatmap'][:, :, 0].shape == (17, 16)


def test_generate_coordinate_and_cell():
Expand Down

0 comments on commit 585ebfa

Please sign in to comment.