diff --git a/mmedit/datasets/pipelines/__init__.py b/mmedit/datasets/pipelines/__init__.py index 61d7b4ae3f..b337dd1b40 100644 --- a/mmedit/datasets/pipelines/__init__.py +++ b/mmedit/datasets/pipelines/__init__.py @@ -1,7 +1,7 @@ from .augmentation import (BinarizeImage, Flip, GenerateFrameIndices, - GenerateFrameIndiceswithPadding, Pad, RandomAffine, - RandomJitter, RandomMaskDilation, RandomTransposeHW, - Resize, TemporalReverse) + GenerateFrameIndiceswithPadding, MirrorSequence, + Pad, RandomAffine, RandomJitter, RandomMaskDilation, + RandomTransposeHW, Resize, TemporalReverse) from .compose import Compose from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown, FixedCrop, ModCrop, PairedRandomCrop) @@ -29,5 +29,5 @@ 'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg', 'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling', 'GenerateTrimapWithDistTransform', 'TransformTrimap', - 'GenerateCoordinateAndCell' + 'GenerateCoordinateAndCell', 'MirrorSequence' ] diff --git a/mmedit/datasets/pipelines/augmentation.py b/mmedit/datasets/pipelines/augmentation.py index ea9faba214..2b76c444cb 100644 --- a/mmedit/datasets/pipelines/augmentation.py +++ b/mmedit/datasets/pipelines/augmentation.py @@ -895,3 +895,42 @@ def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(keys={self.keys}, reverse_ratio={self.reverse_ratio})' return repr_str + + +@PIPELINES.register_module() +class MirrorSequence: + """Extend short sequences (e.g. Vimeo-90K) by mirroring the sequences + + Given a sequence with N frames (x1, ..., xN), extend the sequence to + (x1, ..., xN, xN, ..., x1). + + Args: + keys (list[str]): The frame lists to be extended. + """ + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + """Call function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. + """ + for key in self.keys: + if isinstance(results[key], list): + results[key] = results[key] + results[key][::-1] + else: + raise TypeError('The input must be of class list[nparray]. ' + f'Got {type(results[key])}.') + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(keys={self.keys})') + return repr_str diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 227587105e..a2452cbe5b 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -7,10 +7,11 @@ # yapf: disable from mmedit.datasets.pipelines import (BinarizeImage, Flip, GenerateFrameIndices, - GenerateFrameIndiceswithPadding, Pad, - RandomAffine, RandomJitter, - RandomMaskDilation, RandomTransposeHW, - Resize, TemporalReverse) + GenerateFrameIndiceswithPadding, + MirrorSequence, Pad, RandomAffine, + RandomJitter, RandomMaskDilation, + RandomTransposeHW, Resize, + TemporalReverse) # yapf: enable @@ -622,3 +623,27 @@ def test_temporal_reverse(self): np.testing.assert_almost_equal(results['lq'][0], img_lq1) np.testing.assert_almost_equal(results['lq'][1], img_lq2) np.testing.assert_almost_equal(results['gt'][0], img_gt) + + def mirror_sequence(self): + lqs = [np.random.rand(4, 4, 3) for _ in range(0, 5)] + gts = [np.random.rand(16, 16, 3) for _ in range(0, 5)] + + target_keys = ['lq', 'gt'] + mirror_sequence = MirrorSequence(keys=['lq', 'gt']) + results = dict(lq=lqs, gt=gts) + results = mirror_sequence(results) + + assert self.check_keys_contain(results.keys(), target_keys) + for i in range(0, 5): + np.testing.assert_almost_equal(results['lq'][i], + results['lq'][-i - 1]) + np.testing.assert_almost_equal(results['gt'][i], + results['gt'][-i - 1]) + + assert repr(mirror_sequence) == mirror_sequence.__class__.__name__ + ( + f"(keys=['lq', 'gt'])") + + # each key should contain a list of nparray + with pytest.raises(TypeError): + results = dict(lq=0, gt=gts) + mirror_sequence(results)