Skip to content

Commit

Permalink
rename to 'MirrorSequence'
Browse files Browse the repository at this point in the history
  • Loading branch information
ckkelvinchan committed Apr 14, 2021
1 parent 9acd9a9 commit 4421f51
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 14 deletions.
9 changes: 4 additions & 5 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from .augmentation import (BinarizeImage, Flip, GenerateFrameIndices,
GenerateFrameIndiceswithPadding,
MirrorSequenceExtend, Pad, RandomAffine,
RandomJitter, RandomMaskDilation, RandomTransposeHW,
Resize, TemporalReverse)
GenerateFrameIndiceswithPadding, MirrorSequence,
Pad, RandomAffine, RandomJitter, RandomMaskDilation,
RandomTransposeHW, Resize, TemporalReverse)
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
FixedCrop, ModCrop, PairedRandomCrop)
Expand Down Expand Up @@ -30,5 +29,5 @@
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'MirrorSequenceExtend'
'GenerateCoordinateAndCell', 'MirrorSequence'
]
2 changes: 1 addition & 1 deletion mmedit/datasets/pipelines/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,7 @@ def __repr__(self):


@PIPELINES.register_module()
class MirrorSequenceExtend:
class MirrorSequence:
"""Extend short sequences (e.g. Vimeo-90K) by mirroring the sequences
Given a sequence with N frames (x1, ..., xN), extend the sequence to
Expand Down
15 changes: 7 additions & 8 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mmedit.datasets.pipelines import (BinarizeImage, Flip,
GenerateFrameIndices,
GenerateFrameIndiceswithPadding,
MirrorSequenceExtend, Pad, RandomAffine,
MirrorSequence, Pad, RandomAffine,
RandomJitter, RandomMaskDilation,
RandomTransposeHW, Resize,
TemporalReverse)
Expand Down Expand Up @@ -624,14 +624,14 @@ def test_temporal_reverse(self):
np.testing.assert_almost_equal(results['lq'][1], img_lq2)
np.testing.assert_almost_equal(results['gt'][0], img_gt)

def mirror_sequence_extend(self):
def mirror_sequence(self):
lqs = [np.random.rand(4, 4, 3) for _ in range(0, 5)]
gts = [np.random.rand(16, 16, 3) for _ in range(0, 5)]

target_keys = ['lq', 'gt']
mirror_sequence_extend = MirrorSequenceExtend(keys=['lq', 'gt'])
mirror_sequence = MirrorSequence(keys=['lq', 'gt'])
results = dict(lq=lqs, gt=gts)
results = mirror_sequence_extend(results)
results = mirror_sequence(results)

assert self.check_keys_contain(results.keys(), target_keys)
for i in range(0, 5):
Expand All @@ -640,11 +640,10 @@ def mirror_sequence_extend(self):
np.testing.assert_almost_equal(results['gt'][i],
results['gt'][-i - 1])

assert repr(mirror_sequence_extend
) == mirror_sequence_extend.__class__.__name__ + (
f"(keys=['lq', 'gt'])")
assert repr(mirror_sequence) == mirror_sequence.__class__.__name__ + (
f"(keys=['lq', 'gt'])")

# each key should contain a list of nparray
with pytest.raises(TypeError):
results = dict(lq=0, gt=gts)
mirror_sequence_extend(results)
mirror_sequence(results)

0 comments on commit 4421f51

Please sign in to comment.