Skip to content

Commit

Permalink
Add CropLike (#299)
Browse files Browse the repository at this point in the history
* add ModifySize

* Rename

* Rename and fix

* tiny fix

Co-authored-by: liyinshuo <[email protected]>
  • Loading branch information
Yshuo-Li and liyinshuo authored May 12, 2021
1 parent 41968e2 commit 613b260
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 4 deletions.
5 changes: 3 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
RandomTransposeHW, Resize, TemporalReverse)
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
FixedCrop, ModCrop, PairedRandomCrop)
CropLike, FixedCrop, ModCrop, PairedRandomCrop)
from .down_sampling import RandomDownSampling
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
Expand All @@ -30,5 +30,6 @@
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence'
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike'
]
47 changes: 47 additions & 0 deletions mmedit/datasets/pipelines/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,50 @@ def __call__(self, results):
raise ValueError(f'Wrong img ndim: {img.ndim}.')
results['gt'] = img
return results


@PIPELINES.register_module()
class CropLike:
"""Crop/pad the image in the target_key according to the size of image
in the reference_key .
Args:
target_key (str): The key needs to be cropped.
reference_key (str | None): The reference key, need its size.
Default: None.
"""

def __init__(self, target_key, reference_key=None):

assert reference_key and target_key
self.target_key = target_key
self.reference_key = reference_key

def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation.
Require self.target_key and self.reference_key.
Returns:
dict: A dict containing the processed data and information.
Modify self.target_key.
"""
size = results[self.reference_key].shape
old_image = results[self.target_key]
old_size = old_image.shape
h, w = old_size[:2]
new_size = size[:2] + old_size[2:]
h_cover, w_cover = min(h, size[0]), min(w, size[1])

format_image = np.zeros(new_size, dtype=old_image.dtype)
format_image[:h_cover, :w_cover] = old_image[:h_cover, :w_cover]
results[self.target_key] = format_image

return results

def __repr__(self):
return (self.__class__.__name__ + f' target_key={self.target_key}, ' +
f'reference_key={self.reference_key}')
37 changes: 35 additions & 2 deletions tests/test_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest

from mmedit.datasets.pipelines import (Crop, CropAroundCenter, CropAroundFg,
CropAroundUnknown, FixedCrop, ModCrop,
PairedRandomCrop)
CropAroundUnknown, CropLike, FixedCrop,
ModCrop, PairedRandomCrop)


class TestAugmentations:
Expand Down Expand Up @@ -405,3 +405,36 @@ def test_paired_random_crop(self):
assert v.shape == (32, 32, 3)
np.testing.assert_almost_equal(results['gt'][0], results['gt'][1])
np.testing.assert_almost_equal(results['lq'][0], results['lq'][1])


def test_crop_like():
img = np.uint8(np.random.randn(480, 640, 3) * 255)
img_ref = np.uint8(np.random.randn(512, 512, 3) * 255)

inputs = dict(gt=img, ref=img_ref)
crop_like = CropLike(target_key='gt', reference_key='ref')
results = crop_like(inputs)
assert set(list(results.keys())) == set(['gt', 'ref'])
assert repr(crop_like) == (
crop_like.__class__.__name__ +
f' target_key={crop_like.target_key}, ' +
f'reference_key={crop_like.reference_key}')
assert results['gt'].shape == (512, 512, 3)
sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512]))
assert sum_diff < 1e-6

inputs = dict(gt=img, ref=img_ref[:, :, 0])
crop_like = CropLike(target_key='gt', reference_key='ref')
results = crop_like(inputs)
assert set(list(results.keys())) == set(['gt', 'ref'])
assert results['gt'].shape == (512, 512, 3)
sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512]))
assert sum_diff < 1e-6

inputs = dict(gt=img[:, :, 0], ref=img_ref)
crop_like = CropLike(target_key='gt', reference_key='ref')
results = crop_like(inputs)
assert set(list(results.keys())) == set(['gt', 'ref'])
assert results['gt'].shape == (512, 512)
sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512, 0]))
assert sum_diff < 1e-6

0 comments on commit 613b260

Please sign in to comment.