diff --git a/mmedit/datasets/pipelines/__init__.py b/mmedit/datasets/pipelines/__init__.py
index 1538edca87..ac326c5de3 100644
--- a/mmedit/datasets/pipelines/__init__.py
+++ b/mmedit/datasets/pipelines/__init__.py
@@ -5,7 +5,7 @@
                            RandomTransposeHW, Resize, TemporalReverse)
 from .compose import Compose
 from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
-                   FixedCrop, ModCrop, PairedRandomCrop)
+                   CropLike, FixedCrop, ModCrop, PairedRandomCrop)
 from .down_sampling import RandomDownSampling
 from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
                         ToTensor)
@@ -30,5 +30,6 @@
     'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
     'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
     'GenerateTrimapWithDistTransform', 'TransformTrimap',
-    'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence'
+    'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
+    'CropLike'
 ]
diff --git a/mmedit/datasets/pipelines/crop.py b/mmedit/datasets/pipelines/crop.py
index df4443f6f7..1d1e4e00c9 100644
--- a/mmedit/datasets/pipelines/crop.py
+++ b/mmedit/datasets/pipelines/crop.py
@@ -532,3 +532,68 @@ def __call__(self, results):
             raise ValueError(f'Wrong img ndim: {img.ndim}.')
         results['gt'] = img
         return results
+
+
+@PIPELINES.register_module()
+class CropLike:
+    """Crop or zero-pad the image in target_key to the size of the image in
+    reference_key.
+
+    The top-left region shared by both images is copied; any part of the
+    reference size not covered by the target is filled with zeros. Only the
+    first two dimensions (height, width) follow the reference, trailing
+    dimensions of the target (e.g. channels) are kept.
+
+    Args:
+        target_key (str): The key of the image to be cropped/padded.
+        reference_key (str | None): The key of the image that provides the
+            output height and width. It must be given; None is rejected.
+            Default: None.
+
+    Raises:
+        ValueError: If target_key or reference_key is None or empty.
+    """
+
+    def __init__(self, target_key, reference_key=None):
+        # Raise explicitly instead of asserting, so that the validation is
+        # not stripped when Python runs with -O.
+        if not (reference_key and target_key):
+            raise ValueError(
+                'target_key and reference_key must both be non-empty, '
+                f'but got target_key={target_key!r} and '
+                f'reference_key={reference_key!r}.')
+        self.target_key = target_key
+        self.reference_key = reference_key
+
+    def __call__(self, results):
+        """Call function.
+
+        Args:
+            results (dict): A dict containing the necessary information and
+                data for augmentation.
+                Require self.target_key and self.reference_key.
+
+        Returns:
+            dict: A dict containing the processed data and information.
+                Modify self.target_key.
+        """
+        size = results[self.reference_key].shape
+        old_image = results[self.target_key]
+        old_size = old_image.shape
+        h, w = old_size[:2]
+        # Height and width come from the reference, any trailing dimensions
+        # (e.g. channels) come from the target itself.
+        new_size = size[:2] + old_size[2:]
+        # Extent of the region that exists in both the old and new image.
+        h_cover, w_cover = min(h, size[0]), min(w, size[1])
+
+        # Start from zeros so that uncovered (padded) area stays zero.
+        format_image = np.zeros(new_size, dtype=old_image.dtype)
+        format_image[:h_cover, :w_cover] = old_image[:h_cover, :w_cover]
+        results[self.target_key] = format_image
+
+        return results
+
+    def __repr__(self):
+        return (self.__class__.__name__ + f' target_key={self.target_key}, ' +
+                f'reference_key={self.reference_key}')
diff --git a/tests/test_crop.py b/tests/test_crop.py
index 46a1fbdcf9..2db7c7f380 100644
--- a/tests/test_crop.py
+++ b/tests/test_crop.py
@@ -4,8 +4,8 @@
 import pytest
 
 from mmedit.datasets.pipelines import (Crop, CropAroundCenter, CropAroundFg,
-                                       CropAroundUnknown, FixedCrop, ModCrop,
-                                       PairedRandomCrop)
+                                       CropAroundUnknown, CropLike, FixedCrop,
+                                       ModCrop, PairedRandomCrop)
 
 
 class TestAugmentations:
@@ -405,3 +405,39 @@ def test_paired_random_crop(self):
             assert v.shape == (32, 32, 3)
         np.testing.assert_almost_equal(results['gt'][0], results['gt'][1])
         np.testing.assert_almost_equal(results['lq'][0], results['lq'][1])
+
+
+def test_crop_like():
+    img = np.uint8(np.random.randn(480, 640, 3) * 255)
+    img_ref = np.uint8(np.random.randn(512, 512, 3) * 255)
+
+    inputs = dict(gt=img, ref=img_ref)
+    crop_like = CropLike(target_key='gt', reference_key='ref')
+    results = crop_like(inputs)
+    assert set(list(results.keys())) == set(['gt', 'ref'])
+    assert repr(crop_like) == (
+        crop_like.__class__.__name__ +
+        f' target_key={crop_like.target_key}, ' +
+        f'reference_key={crop_like.reference_key}')
+    assert results['gt'].shape == (512, 512, 3)
+    sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512]))
+    assert sum_diff < 1e-6
+
+    inputs = dict(gt=img, ref=img_ref[:, :, 0])
+    crop_like = CropLike(target_key='gt', reference_key='ref')
+    results = crop_like(inputs)
+    assert set(list(results.keys())) == set(['gt', 'ref'])
+    assert results['gt'].shape == (512, 512, 3)
+    sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512]))
+    assert sum_diff < 1e-6
+
+    inputs = dict(gt=img[:, :, 0], ref=img_ref)
+    crop_like = CropLike(target_key='gt', reference_key='ref')
+    results = crop_like(inputs)
+    assert set(list(results.keys())) == set(['gt', 'ref'])
+    assert results['gt'].shape == (512, 512)
+    sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512, 0]))
+    assert sum_diff < 1e-6
+
+    with pytest.raises(ValueError):
+        CropLike(target_key='gt')