Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CropLike #299

Merged
merged 4 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
RandomTransposeHW, Resize, TemporalReverse)
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
FixedCrop, ModCrop, PairedRandomCrop)
FixedCrop, ModCrop, ModifySize, PairedRandomCrop)
from .down_sampling import RandomDownSampling
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
Expand All @@ -30,5 +30,6 @@
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence'
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'ModifySize'
]
61 changes: 61 additions & 0 deletions mmedit/datasets/pipelines/crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,64 @@ def __call__(self, results):
raise ValueError(f'Wrong img ndim: {img.ndim}.')
results['gt'] = img
return results


@PIPELINES.register_module()
class ModifySize:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Random thought: how about rename it to "TopLeftCrop"?

"""Modify size.

Modify the size of image by cropping or padding. Align upper-left.

Args:
target_key (str): The target key.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a "target key"?

source_key (str | None): The source key. Default: None.
target_size (Tuple[int] | None): The target size. [h, w]
Default: None.

The priority of getting 'target size' is:
1, results[source_key].shape
2, target_size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a note for the required, added and modified keys.

"""

def __init__(self, target_key, source_key=None, target_size=None):

assert (source_key or target_size), 'Need source_key or target_size'
self.target_key = target_key
self.source_key = source_key
if isinstance(target_size, int):
self.target_size = (target_size, target_size)
else:
self.target_size = target_size

def __call__(self, results):
"""Call function.

Args:
results (dict): A dict containing the necessary information and
data for augmentation.

Returns:
dict: A dict containing the processed data and information.
"""
if self.source_key and self.source_key in results:
size = results[self.source_key].shape
elif self.target_size:
size = self.target_size
else:
raise ValueError('Need effective target_key or target_size')
old_image = results[self.target_key]
old_size = old_image.shape
h, w = old_size[:2]
new_size = size[:2] + old_size[2:]
h_cover, w_cover = min(h, size[0]), min(w, size[1])

format_image = np.zeros(new_size, dtype=old_image.dtype)
format_image[:h_cover, :w_cover] = old_image[:h_cover, :w_cover]
results[self.target_key] = format_image

return results

def __repr__(self):
return self.__class__.__name__ + (f' target_key={self.target_key}, ' +
f'source_key={self.source_key}, ' +
f'target_size={self.target_size}')
44 changes: 43 additions & 1 deletion tests/test_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mmedit.datasets.pipelines import (Crop, CropAroundCenter, CropAroundFg,
CropAroundUnknown, FixedCrop, ModCrop,
PairedRandomCrop)
ModifySize, PairedRandomCrop)


class TestAugmentations:
Expand Down Expand Up @@ -405,3 +405,45 @@ def test_paired_random_crop(self):
assert v.shape == (32, 32, 3)
np.testing.assert_almost_equal(results['gt'][0], results['gt'][1])
np.testing.assert_almost_equal(results['lq'][0], results['lq'][1])


def test_generate_by_resize():
img = np.uint8(np.random.randn(480, 640, 3) * 255)
img_ref = np.uint8(np.random.randn(512, 512, 3) * 255)

inputs = dict(gt=img, ref=img_ref)
modify_size = ModifySize(target_key='gt', source_key='ref')
results = modify_size(inputs)
assert set(list(results.keys())) == set(['gt', 'ref'])
assert repr(modify_size) == (
modify_size.__class__.__name__ +
f' target_key={modify_size.target_key}, ' +
f'source_key={modify_size.source_key}, ' +
f'target_size={modify_size.target_size}')
assert results['gt'].shape == (512, 512, 3)
sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512]))
assert sum_diff < 1e-6

inputs = dict(gt=img)
modify_size = ModifySize(target_key='gt', target_size=(300, 700))
results = modify_size(inputs)
assert set(list(results.keys())) == set(['gt'])
assert results['gt'].shape == (300, 700, 3)
sum_diff = np.sum(abs(results['gt'][:300, :640] - img[:300, :640]))
assert sum_diff < 1e-6

inputs = dict(gt=img, ref=img_ref[:, :, 0])
modify_size = ModifySize(target_key='gt', source_key='ref')
results = modify_size(inputs)
assert set(list(results.keys())) == set(['gt', 'ref'])
assert results['gt'].shape == (512, 512, 3)
sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512]))
assert sum_diff < 1e-6

inputs = dict(gt=img[:, :, 0], ref=img_ref)
modify_size = ModifySize(target_key='gt', source_key='ref')
results = modify_size(inputs)
assert set(list(results.keys())) == set(['gt', 'ref'])
assert results['gt'].shape == (512, 512)
sum_diff = np.sum(abs(results['gt'][:480, :512] - img[:480, :512, 0]))
assert sum_diff < 1e-6