diff --git a/mmedit/datasets/pipelines/__init__.py b/mmedit/datasets/pipelines/__init__.py index a135734401..c13aa348d3 100644 --- a/mmedit/datasets/pipelines/__init__.py +++ b/mmedit/datasets/pipelines/__init__.py @@ -8,7 +8,7 @@ CropLike, FixedCrop, ModCrop, PairedRandomCrop) from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor, ToTensor) -from .generate_coordinate_and_cell import GenerateCoordinateAndCell +from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadImageFromFileList, LoadMask, LoadPairedImageFromFile, RandomLoadResizeBg) @@ -31,5 +31,5 @@ 'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling', 'GenerateTrimapWithDistTransform', 'TransformTrimap', 'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence', - 'CropLike' + 'CropLike', 'GenerateHeatmap' ] diff --git a/mmedit/datasets/pipelines/generate_coordinate_and_cell.py b/mmedit/datasets/pipelines/generate_assistant.py similarity index 57% rename from mmedit/datasets/pipelines/generate_coordinate_and_cell.py rename to mmedit/datasets/pipelines/generate_assistant.py index 0e9bcf11cf..5de6e78b0f 100644 --- a/mmedit/datasets/pipelines/generate_coordinate_and_cell.py +++ b/mmedit/datasets/pipelines/generate_assistant.py @@ -5,6 +5,80 @@ from .utils import make_coord +@PIPELINES.register_module() +class GenerateHeatmap: + """Generate heatmap from keypoint. + + Args: + keypoint (str): Key of keypoint in dict. + ori_size (int | Tuple[int]): Original image size of keypoint. + target_size (int | Tuple[int]): Target size of heatmap. + sigma (float): Sigma parameter of heatmap. Default: 1.0 + """ + + def __init__(self, keypoint, ori_size, target_size, sigma=1.0): + if isinstance(ori_size, int): + ori_size = (ori_size, ori_size) + else: + ori_size = ori_size[:2] + if isinstance(target_size, int): + target_size = (target_size, target_size) + else: + target_size = target_size[:2] + self.size_ratio = (target_size[0] / ori_size[0], + target_size[1] / ori_size[1]) + self.keypoint = keypoint + self.sigma = sigma + self.target_size = target_size + self.ori_size = ori_size + + def __call__(self, results): + """Call function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. Require keypoint. + + Returns: + dict: A dict containing the processed data and information. + Add 'heatmap'. + """ + keypoint_list = [(keypoint[0] * self.size_ratio[0], + keypoint[1] * self.size_ratio[1]) + for keypoint in results[self.keypoint]] + heatmap_list = [ + self._generate_one_heatmap(keypoint) for keypoint in keypoint_list + ] + results['heatmap'] = np.stack(heatmap_list, axis=2) + return results + + def _generate_one_heatmap(self, keypoint): + """Generate One Heatmap. + + Args: + landmark (Tuple[float]): Location of a landmark. + + results: + heatmap (np.ndarray): A heatmap of landmark. + """ + w, h = self.target_size + + x_range = np.arange(start=0, stop=w, dtype=int) + y_range = np.arange(start=0, stop=h, dtype=int) + grid_x, grid_y = np.meshgrid(x_range, y_range) + dist2 = (grid_x - keypoint[0])**2 + (grid_y - keypoint[1])**2 + exponent = dist2 / 2.0 / self.sigma / self.sigma + heatmap = np.exp(-exponent) + return heatmap + + def __repr__(self): + return (f'{self.__class__.__name__}, ' + f'keypoint={self.keypoint}, ' + f'ori_size={self.ori_size}, ' + f'target_size={self.target_size}, ' + f'sigma={self.sigma}') + + @PIPELINES.register_module() class GenerateCoordinateAndCell: """Generate coordinate and cell. diff --git a/tests/test_generate_coordinate_and_cell.py b/tests/test_generate_assistant.py similarity index 53% rename from tests/test_generate_coordinate_and_cell.py rename to tests/test_generate_assistant.py index 718108fb0e..3037ccb271 100644 --- a/tests/test_generate_coordinate_and_cell.py +++ b/tests/test_generate_assistant.py @@ -1,6 +1,26 @@ import torch -from mmedit.datasets.pipelines import GenerateCoordinateAndCell +from mmedit.datasets.pipelines import (GenerateCoordinateAndCell, + GenerateHeatmap) + + +def test_generate_heatmap(): + inputs = dict(landmark=[(1, 2), (3, 4)]) + generate_heatmap = GenerateHeatmap('landmark', 4, 16) + results = generate_heatmap(inputs) + assert set(list(results.keys())) == set(['landmark', 'heatmap']) + assert results['heatmap'][:, :, 0].shape == (16, 16) + assert repr(generate_heatmap) == ( + f'{generate_heatmap.__class__.__name__}, ' + f'keypoint={generate_heatmap.keypoint}, ' + f'ori_size={generate_heatmap.ori_size}, ' + f'target_size={generate_heatmap.target_size}, ' + f'sigma={generate_heatmap.sigma}') + + generate_heatmap = GenerateHeatmap('landmark', (4, 5), (16, 17)) + results = generate_heatmap(inputs) + assert set(list(results.keys())) == set(['landmark', 'heatmap']) + assert results['heatmap'][:, :, 0].shape == (17, 16) def test_generate_coordinate_and_cell():