Skip to content

Commit

Permalink
Add MirrorExtend for training BasicVSR and IconVSR (open-mmlab#253)
Browse files Browse the repository at this point in the history
* Add MirrorExtend

* rename to MirrorSequenceExtend

* rename to 'MirrorSequence'
  • Loading branch information
ckkelvinchan authored Apr 14, 2021
1 parent 76c00e8 commit 30522aa
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 8 deletions.
8 changes: 4 additions & 4 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .augmentation import (BinarizeImage, Flip, GenerateFrameIndices,
GenerateFrameIndiceswithPadding, Pad, RandomAffine,
RandomJitter, RandomMaskDilation, RandomTransposeHW,
Resize, TemporalReverse)
GenerateFrameIndiceswithPadding, MirrorSequence,
Pad, RandomAffine, RandomJitter, RandomMaskDilation,
RandomTransposeHW, Resize, TemporalReverse)
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
FixedCrop, ModCrop, PairedRandomCrop)
Expand Down Expand Up @@ -29,5 +29,5 @@
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell'
'GenerateCoordinateAndCell', 'MirrorSequence'
]
39 changes: 39 additions & 0 deletions mmedit/datasets/pipelines/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -895,3 +895,42 @@ def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(keys={self.keys}, reverse_ratio={self.reverse_ratio})'
return repr_str


@PIPELINES.register_module()
class MirrorSequence:
"""Extend short sequences (e.g. Vimeo-90K) by mirroring the sequences
Given a sequence with N frames (x1, ..., xN), extend the sequence to
(x1, ..., xN, xN, ..., x1).
Args:
keys (list[str]): The frame lists to be extended.
"""

def __init__(self, keys):
self.keys = keys

def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Returns:
dict: A dict containing the processed data and information.
"""
for key in self.keys:
if isinstance(results[key], list):
results[key] = results[key] + results[key][::-1]
else:
raise TypeError('The input must be of class list[nparray]. '
f'Got {type(results[key])}.')

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(keys={self.keys})')
return repr_str
33 changes: 29 additions & 4 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
# yapf: disable
from mmedit.datasets.pipelines import (BinarizeImage, Flip,
GenerateFrameIndices,
GenerateFrameIndiceswithPadding, Pad,
RandomAffine, RandomJitter,
RandomMaskDilation, RandomTransposeHW,
Resize, TemporalReverse)
GenerateFrameIndiceswithPadding,
MirrorSequence, Pad, RandomAffine,
RandomJitter, RandomMaskDilation,
RandomTransposeHW, Resize,
TemporalReverse)

# yapf: enable

Expand Down Expand Up @@ -622,3 +623,27 @@ def test_temporal_reverse(self):
np.testing.assert_almost_equal(results['lq'][0], img_lq1)
np.testing.assert_almost_equal(results['lq'][1], img_lq2)
np.testing.assert_almost_equal(results['gt'][0], img_gt)

def mirror_sequence(self):
lqs = [np.random.rand(4, 4, 3) for _ in range(0, 5)]
gts = [np.random.rand(16, 16, 3) for _ in range(0, 5)]

target_keys = ['lq', 'gt']
mirror_sequence = MirrorSequence(keys=['lq', 'gt'])
results = dict(lq=lqs, gt=gts)
results = mirror_sequence(results)

assert self.check_keys_contain(results.keys(), target_keys)
for i in range(0, 5):
np.testing.assert_almost_equal(results['lq'][i],
results['lq'][-i - 1])
np.testing.assert_almost_equal(results['gt'][i],
results['gt'][-i - 1])

assert repr(mirror_sequence) == mirror_sequence.__class__.__name__ + (
f"(keys=['lq', 'gt'])")

# each key should contain a list of nparray
with pytest.raises(TypeError):
results = dict(lq=0, gt=gts)
mirror_sequence(results)

0 comments on commit 30522aa

Please sign in to comment.