Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Merge resize and sr_resize. #310

Merged
merged 3 commits into from
May 18, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
GenerateTrimap, GenerateTrimapWithDistTransform,
MergeFgAndBg, PerturbBg, TransformTrimap)
from .normalization import Normalize, RescaleToZeroOne
from .sr_resize import RandomDownSampling, SRResize
from .random_down_sampling import RandomDownSampling

__all__ = [
'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask',
Expand All @@ -29,7 +29,7 @@
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap', 'SRResize',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike'
]
36 changes: 27 additions & 9 deletions mmedit/datasets/pipelines/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Resize:
keys (list[str]): The images to be resized.
scale (float | Tuple[int]): If scale is Tuple(int), target spatial
size (h, w). Otherwise, target spatial size is scaled by input
size. If any of scale is -1, we will rescale short edge.
size.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason for the deletion?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code doesn't support scale=-1
if isinstance(scale, float): if scale <= 0: raise ValueError(f'Invalid scale {scale}, must be positive.')

Note that when it is used, `size_factor` and `max_size` are
useless. Default: None
keep_ratio (bool): If set to True, images will be resized without
Expand All @@ -50,6 +50,12 @@ class Resize:
interpolation (str): Algorithm used for interpolation:
"nearest" | "bilinear" | "bicubic" | "area" | "lanczos".
Default: "bilinear".
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used.
Default: "pillow".
Yshuo-Li marked this conversation as resolved.
Show resolved Hide resolved
output_keys (list[str] | None): The resized images. Default: None
Note that if it is not `None`, its length shuld be equal to keys.
"""

def __init__(self,
Expand All @@ -58,8 +64,14 @@ def __init__(self,
keep_ratio=False,
size_factor=None,
max_size=None,
interpolation='bilinear'):
interpolation='bilinear',
backend=None,
output_keys=None):
assert keys, 'Keys should not be empty.'
if output_keys:
assert len(output_keys) == len(keys)
else:
output_keys = keys
if size_factor:
assert scale is None, ('When size_factor is used, scale should ',
f'be None. But received {scale}.')
Expand All @@ -83,25 +95,29 @@ def __init__(self,
f'Scale must be None, float or tuple of int, but got '
f'{type(scale)}.')
self.keys = keys
self.output_keys = output_keys
self.scale = scale
self.size_factor = size_factor
self.max_size = max_size
self.keep_ratio = keep_ratio
self.interpolation = interpolation
self.backend = backend

def _resize(self, img):
if self.keep_ratio:
img, self.scale_factor = mmcv.imrescale(
img,
self.scale,
return_scale=True,
interpolation=self.interpolation)
interpolation=self.interpolation,
backend=self.backend)
else:
img, w_scale, h_scale = mmcv.imresize(
img,
self.scale,
return_scale=True,
interpolation=self.interpolation)
interpolation=self.interpolation,
backend=self.backend)
self.scale_factor = np.array((w_scale, h_scale), dtype=np.float32)
return img

Expand All @@ -125,21 +141,23 @@ def __call__(self, results):
new_w = min(self.max_size - (self.max_size % self.size_factor),
new_w)
self.scale = (new_w, new_h)
for key in self.keys:
results[key] = self._resize(results[key])
if len(results[key].shape) == 2:
results[key] = np.expand_dims(results[key], axis=2)
for key, out_key in zip(self.keys, self.output_keys):
results[out_key] = self._resize(results[key])
if len(results[out_key].shape) == 2:
results[out_key] = np.expand_dims(results[out_key], axis=2)

results['scale_factor'] = self.scale_factor
results['keep_ratio'] = self.keep_ratio
results['interpolation'] = self.interpolation
results['backend'] = self.backend

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (
f'(keys={self.keys}, scale={self.scale}, '
f'(keys={self.keys}, output_keys={self.output_keys}, '
f'scale={self.scale}, '
f'keep_ratio={self.keep_ratio}, size_factor={self.size_factor}, '
f'max_size={self.max_size},interpolation={self.interpolation})')
return repr_str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,76 +7,6 @@
from ..registry import PIPELINES


@PIPELINES.register_module()
class SRResize:
"""Resize image by a scale, including upsampling and downsampling.

Image will be loaded from the input_key and the result will be saved
in the specified output_key (can equal to input_key).

Args:
scale (float): The resampling scale. scale > 0.
scale > 1: upsampling.
scale < 1: downsampling.
input_key (str): The input key.
output_key (str): The output key.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
"hamming" for 'pillow' backend.
Default: "bicubic".
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used.
Default: "pillow".
"""

def __init__(self,
scale,
input_key,
output_key,
interpolation='bicubic',
backend='pillow'):
self.scale = scale
self.input_key = input_key
self.output_key = output_key
self.interpolation = interpolation
self.backend = backend

def __call__(self, results):
"""Call function.

Args:
results (dict): A dict containing the necessary information and
data for augmentation. self.input_key is required.

Returns:
dict: A dict containing the processed data and information.
supplement self.output_key to keys.
"""
assert self.input_key in results, f'Cannot find {self.input_key}.'
image_in = results[self.input_key]
h_in, w_in = image_in.shape[:2]
h_out = math.floor(h_in * self.scale + 1e-9)
w_out = math.floor(w_in * self.scale + 1e-9)
image_out = resize_fn(image_in, (w_out, h_out), self.interpolation,
self.backend)

results[self.output_key] = image_out

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f' scale={self.scale}, '
f'input_key={self.input_key}, '
f'output_key={self.output_key}, '
f'interpolation={self.interpolation}, '
f'backend={self.backend}')

return repr_str


@PIPELINES.register_module()
class RandomDownSampling:
"""Generate LQ image from GT (and crop), which will randomly pick a scale.
Expand Down
23 changes: 18 additions & 5 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,21 @@ def test_resize(self):
assert results['gt_img'].shape[:2] == (128, 128)

# test input with shape (256, 256)
results = dict(gt_img=self.results['img'][..., 0].copy())
resize = Resize(['gt_img'], scale=(128, 128), keep_ratio=False)
results = dict(gt_img=self.results['img'][..., 0].copy(), alpha=alpha)
resize = Resize(['gt_img', 'alpha'],
scale=(128, 128),
Yshuo-Li marked this conversation as resolved.
Show resolved Hide resolved
keep_ratio=False,
output_keys=['lq_img', 'beta'])
results = resize(results)
assert results['gt_img'].shape == (128, 128, 1)
assert results['gt_img'].shape == (256, 256)
assert results['lq_img'].shape == (128, 128, 1)
assert results['alpha'].shape == (240, 320)
assert results['beta'].shape == (128, 128, 1)

name_ = str(resize_keep_ratio)
assert name_ == resize_keep_ratio.__class__.__name__ + (
f"(keys={['gt_img']}, scale=(128, 128), "
f"(keys={['gt_img']}, output_keys={['gt_img']}, "
Yshuo-Li marked this conversation as resolved.
Show resolved Hide resolved
'scale=(128, 128), '
f'keep_ratio={False}, size_factor=None, '
'max_size=None,interpolation=bilinear)')

Expand Down Expand Up @@ -673,9 +680,15 @@ def mirror_sequence(self):
results['gt'][-i - 1])

assert repr(mirror_sequence) == mirror_sequence.__class__.__name__ + (
f"(keys=['lq', 'gt'])")
"(keys=['lq', 'gt'])")

# each key should contain a list of nparray
with pytest.raises(TypeError):
results = dict(lq=0, gt=gts)
mirror_sequence(results)


if __name__ == '__main__':
Yshuo-Li marked this conversation as resolved.
Show resolved Hide resolved
aug = TestAugmentations()
aug.setup_class()
aug.test_resize()
23 changes: 1 addition & 22 deletions tests/test_sr_resize.py → tests/test_random_down_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from mmedit.datasets.pipelines import RandomDownSampling, SRResize
from mmedit.datasets.pipelines import RandomDownSampling


def test_random_down_sampling():
Expand Down Expand Up @@ -31,24 +31,3 @@ def test_random_down_sampling():
f'patch_size={down_sampling2.patch_size}, ' +
f'interpolation={down_sampling2.interpolation}, ' +
f'backend={down_sampling2.backend}')


def test_sr_resize():
img = np.uint8(np.random.randn(480, 640, 3) * 255)
inputs = dict(gt=img)
re_size = SRResize(scale=1 / 4, input_key='gt', output_key='lq')
results = re_size(inputs)
assert set(list(results.keys())) == set(['gt', 'lq'])
assert results['lq'].shape == (120, 160, 3)
assert repr(re_size) == (
re_size.__class__.__name__ + f' scale={re_size.scale}, ' +
f'input_key={re_size.input_key}, ' +
f'output_key={re_size.output_key}, ' +
f'interpolation={re_size.interpolation}, ' +
f'backend={re_size.backend}')

inputs = dict(gt=img)
re_size = SRResize(scale=2, input_key='gt', output_key='gt')
results = re_size(inputs)
assert set(list(results.keys())) == set(['gt'])
assert results['gt'].shape == (960, 1280, 3)