Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Generate Heatmap #336

Merged
merged 5 commits into from
Jun 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
CropLike, FixedCrop, ModCrop, PairedRandomCrop)
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
from .generate_coordinate_and_cell import GenerateCoordinateAndCell
from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
LoadImageFromFileList, LoadMask, LoadPairedImageFromFile,
RandomLoadResizeBg)
Expand All @@ -31,5 +31,5 @@
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike'
'CropLike', 'GenerateHeatmap'
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,80 @@
from .utils import make_coord


@PIPELINES.register_module()
class GenerateHeatmap:
"""Generate heatmap from keypoint.

Args:
keypoint (str): Key of keypoint in dict.
ori_size (int | Tuple[int]): Original image size of keypoint.
target_size (int | Tuple[int]): Target size of heatmap.
sigma (float): Sigma parameter of heatmap. Default: 1.0
"""

def __init__(self, keypoint, ori_size, target_size, sigma=1.0):
if isinstance(ori_size, int):
ori_size = (ori_size, ori_size)
else:
ori_size = ori_size[:2]
if isinstance(target_size, int):
target_size = (target_size, target_size)
else:
target_size = target_size[:2]
self.size_ratio = (target_size[0] / ori_size[0],
target_size[1] / ori_size[1])
self.keypoint = keypoint
self.sigma = sigma
self.target_size = target_size
self.ori_size = ori_size

def __call__(self, results):
"""Call function.

Args:
results (dict): A dict containing the necessary information and
data for augmentation. Require keypoint.

Returns:
dict: A dict containing the processed data and information.
Add 'heatmap'.
"""
keypoint_list = [(keypoint[0] * self.size_ratio[0],
keypoint[1] * self.size_ratio[1])
for keypoint in results[self.keypoint]]
heatmap_list = [
self._generate_one_heatmap(keypoint) for keypoint in keypoint_list
]
results['heatmap'] = np.stack(heatmap_list, axis=2)
return results

def _generate_one_heatmap(self, keypoint):
"""Generate One Heatmap.

Args:
landmark (Tuple[float]): Location of a landmark.

results:
heatmap (np.ndarray): A heatmap of landmark.
"""
w, h = self.target_size

x_range = np.arange(start=0, stop=w, dtype=int)
y_range = np.arange(start=0, stop=h, dtype=int)
grid_x, grid_y = np.meshgrid(x_range, y_range)
dist2 = (grid_x - keypoint[0])**2 + (grid_y - keypoint[1])**2
exponent = dist2 / 2.0 / self.sigma / self.sigma
heatmap = np.exp(-exponent)
return heatmap

def __repr__(self):
return (f'{self.__class__.__name__}, '
f'keypoint={self.keypoint}, '
f'ori_size={self.ori_size}, '
f'target_size={self.target_size}, '
f'sigma={self.sigma}')


@PIPELINES.register_module()
class GenerateCoordinateAndCell:
"""Generate coordinate and cell.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import torch

from mmedit.datasets.pipelines import GenerateCoordinateAndCell
from mmedit.datasets.pipelines import (GenerateCoordinateAndCell,
GenerateHeatmap)


def test_generate_heatmap():
inputs = dict(landmark=[(1, 2), (3, 4)])
generate_heatmap = GenerateHeatmap('landmark', 4, 16)
results = generate_heatmap(inputs)
assert set(list(results.keys())) == set(['landmark', 'heatmap'])
assert results['heatmap'][:, :, 0].shape == (16, 16)
assert repr(generate_heatmap) == (
f'{generate_heatmap.__class__.__name__}, '
f'keypoint={generate_heatmap.keypoint}, '
f'ori_size={generate_heatmap.ori_size}, '
f'target_size={generate_heatmap.target_size}, '
f'sigma={generate_heatmap.sigma}')

generate_heatmap = GenerateHeatmap('landmark', (4, 5), (16, 17))
results = generate_heatmap(inputs)
assert set(list(results.keys())) == set(['landmark', 'heatmap'])
assert results['heatmap'][:, :, 0].shape == (17, 16)


def test_generate_coordinate_and_cell():
Expand Down