Skip to content

Commit

Permalink
[Fix] Merge resize and sr_resize. (#310)
Browse files Browse the repository at this point in the history
* [Fix] Merge resize and sr_resize.

* Fix

* Fix

Co-authored-by: liyinshuo <[email protected]>
  • Loading branch information
Yshuo-Li and liyinshuo authored May 18, 2021
1 parent ee2d010 commit a0eaf22
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 108 deletions.
4 changes: 2 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
GenerateTrimap, GenerateTrimapWithDistTransform,
MergeFgAndBg, PerturbBg, TransformTrimap)
from .normalization import Normalize, RescaleToZeroOne
from .sr_resize import RandomDownSampling, SRResize
from .random_down_sampling import RandomDownSampling

__all__ = [
'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask',
Expand All @@ -29,7 +29,7 @@
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap', 'SRResize',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike'
]
36 changes: 27 additions & 9 deletions mmedit/datasets/pipelines/augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class Resize:
keys (list[str]): The images to be resized.
scale (float | Tuple[int]): If scale is Tuple(int), target spatial
size (h, w). Otherwise, target spatial size is scaled by input
size. If any of scale is -1, we will rescale short edge.
size.
Note that when it is used, `size_factor` and `max_size` are
useless. Default: None
keep_ratio (bool): If set to True, images will be resized without
Expand All @@ -50,6 +50,12 @@ class Resize:
interpolation (str): Algorithm used for interpolation:
"nearest" | "bilinear" | "bicubic" | "area" | "lanczos".
Default: "bilinear".
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used.
Default: None.
output_keys (list[str] | None): The resized images. Default: None
Note that if it is not `None`, its length should be equal to that of `keys`.
"""

def __init__(self,
Expand All @@ -58,8 +64,14 @@ def __init__(self,
keep_ratio=False,
size_factor=None,
max_size=None,
interpolation='bilinear'):
interpolation='bilinear',
backend=None,
output_keys=None):
assert keys, 'Keys should not be empty.'
if output_keys:
assert len(output_keys) == len(keys)
else:
output_keys = keys
if size_factor:
assert scale is None, ('When size_factor is used, scale should ',
f'be None. But received {scale}.')
Expand All @@ -83,25 +95,29 @@ def __init__(self,
f'Scale must be None, float or tuple of int, but got '
f'{type(scale)}.')
self.keys = keys
self.output_keys = output_keys
self.scale = scale
self.size_factor = size_factor
self.max_size = max_size
self.keep_ratio = keep_ratio
self.interpolation = interpolation
self.backend = backend

def _resize(self, img):
if self.keep_ratio:
img, self.scale_factor = mmcv.imrescale(
img,
self.scale,
return_scale=True,
interpolation=self.interpolation)
interpolation=self.interpolation,
backend=self.backend)
else:
img, w_scale, h_scale = mmcv.imresize(
img,
self.scale,
return_scale=True,
interpolation=self.interpolation)
interpolation=self.interpolation,
backend=self.backend)
self.scale_factor = np.array((w_scale, h_scale), dtype=np.float32)
return img

Expand All @@ -125,21 +141,23 @@ def __call__(self, results):
new_w = min(self.max_size - (self.max_size % self.size_factor),
new_w)
self.scale = (new_w, new_h)
for key in self.keys:
results[key] = self._resize(results[key])
if len(results[key].shape) == 2:
results[key] = np.expand_dims(results[key], axis=2)
for key, out_key in zip(self.keys, self.output_keys):
results[out_key] = self._resize(results[key])
if len(results[out_key].shape) == 2:
results[out_key] = np.expand_dims(results[out_key], axis=2)

results['scale_factor'] = self.scale_factor
results['keep_ratio'] = self.keep_ratio
results['interpolation'] = self.interpolation
results['backend'] = self.backend

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (
f'(keys={self.keys}, scale={self.scale}, '
f'(keys={self.keys}, output_keys={self.output_keys}, '
f'scale={self.scale}, '
f'keep_ratio={self.keep_ratio}, size_factor={self.size_factor}, '
f'max_size={self.max_size},interpolation={self.interpolation})')
return repr_str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,76 +7,6 @@
from ..registry import PIPELINES


@PIPELINES.register_module()
class SRResize:
    """Rescale one image by a fixed factor and store it under another key.

    The image is read from ``results[input_key]`` and the resized result is
    written to ``results[output_key]``. The two keys may be identical, in
    which case the input entry is overwritten.

    Args:
        scale (float): The resampling scale, expected to be > 0.
            scale > 1: upsampling.
            scale < 1: downsampling.
        input_key (str): Key of the image to read.
        output_key (str): Key under which the resized image is stored.
        interpolation (str): Interpolation method. Accepted values are
            "nearest", "bilinear", "bicubic", "area", "lanczos" for the
            'cv2' backend, and "nearest", "bilinear", "bicubic", "box",
            "lanczos", "hamming" for the 'pillow' backend.
            Default: "bicubic".
        backend (str | None): The image resize backend type. Options are
            `cv2`, `pillow`, `None`. If backend is None, the global
            imread_backend specified by ``mmcv.use_backend()`` will be used.
            Default: "pillow".
    """

    def __init__(self,
                 scale,
                 input_key,
                 output_key,
                 interpolation='bicubic',
                 backend='pillow'):
        # Where to read from / write to in the results dict.
        self.input_key, self.output_key = input_key, output_key
        # How to resize.
        self.scale = scale
        self.interpolation = interpolation
        self.backend = backend

    def __call__(self, results):
        """Resize ``results[self.input_key]`` into ``results[self.output_key]``.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. ``self.input_key`` is required.

        Returns:
            dict: The same dict, with ``self.output_key`` added (or
                overwritten) with the resized image.
        """
        assert self.input_key in results, f'Cannot find {self.input_key}.'

        source = results[self.input_key]
        height, width = source.shape[:2]

        # The small epsilon is presumably there so that a product which is
        # mathematically an integer is not floored one pixel short because
        # of floating-point rounding.
        new_height = math.floor(height * self.scale + 1e-9)
        new_width = math.floor(width * self.scale + 1e-9)

        # The target size is passed as (width, height).
        results[self.output_key] = resize_fn(source, (new_width, new_height),
                                             self.interpolation, self.backend)
        return results

    def __repr__(self):
        settings = (f' scale={self.scale}, '
                    f'input_key={self.input_key}, '
                    f'output_key={self.output_key}, '
                    f'interpolation={self.interpolation}, '
                    f'backend={self.backend}')
        return self.__class__.__name__ + settings


@PIPELINES.register_module()
class RandomDownSampling:
"""Generate LQ image from GT (and crop), which will randomly pick a scale.
Expand Down
17 changes: 12 additions & 5 deletions tests/test_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,21 @@ def test_resize(self):
assert results['gt_img'].shape[:2] == (128, 128)

# test input with shape (256, 256)
results = dict(gt_img=self.results['img'][..., 0].copy())
resize = Resize(['gt_img'], scale=(128, 128), keep_ratio=False)
results = dict(gt_img=self.results['img'][..., 0].copy(), alpha=alpha)
resize = Resize(['gt_img', 'alpha'],
scale=(128, 128),
keep_ratio=False,
output_keys=['lq_img', 'beta'])
results = resize(results)
assert results['gt_img'].shape == (128, 128, 1)
assert results['gt_img'].shape == (256, 256)
assert results['lq_img'].shape == (128, 128, 1)
assert results['alpha'].shape == (240, 320)
assert results['beta'].shape == (128, 128, 1)

name_ = str(resize_keep_ratio)
assert name_ == resize_keep_ratio.__class__.__name__ + (
f"(keys={['gt_img']}, scale=(128, 128), "
"(keys=['gt_img'], output_keys=['gt_img'], "
'scale=(128, 128), '
f'keep_ratio={False}, size_factor=None, '
'max_size=None,interpolation=bilinear)')

Expand Down Expand Up @@ -673,7 +680,7 @@ def mirror_sequence(self):
results['gt'][-i - 1])

assert repr(mirror_sequence) == mirror_sequence.__class__.__name__ + (
f"(keys=['lq', 'gt'])")
"(keys=['lq', 'gt'])")

# each key should contain a list of nparray
with pytest.raises(TypeError):
Expand Down
23 changes: 1 addition & 22 deletions tests/test_sr_resize.py → tests/test_random_down_sampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from mmedit.datasets.pipelines import RandomDownSampling, SRResize
from mmedit.datasets.pipelines import RandomDownSampling


def test_random_down_sampling():
Expand Down Expand Up @@ -31,24 +31,3 @@ def test_random_down_sampling():
f'patch_size={down_sampling2.patch_size}, ' +
f'interpolation={down_sampling2.interpolation}, ' +
f'backend={down_sampling2.backend}')


def test_sr_resize():
    """Check SRResize output keys, output shapes and repr."""
    image = np.uint8(np.random.randn(480, 640, 3) * 255)

    # Downsample by 4 into a separate key: the source key must be kept.
    downsampler = SRResize(scale=1 / 4, input_key='gt', output_key='lq')
    outputs = downsampler(dict(gt=image))
    assert set(list(outputs.keys())) == set(['gt', 'lq'])
    assert outputs['lq'].shape == (120, 160, 3)

    expected_repr = (
        downsampler.__class__.__name__ + f' scale={downsampler.scale}, ' +
        f'input_key={downsampler.input_key}, ' +
        f'output_key={downsampler.output_key}, ' +
        f'interpolation={downsampler.interpolation}, ' +
        f'backend={downsampler.backend}')
    assert repr(downsampler) == expected_repr

    # Upsample by 2 in place: no extra key may appear.
    upsampler = SRResize(scale=2, input_key='gt', output_key='gt')
    outputs = upsampler(dict(gt=image))
    assert set(list(outputs.keys())) == set(['gt'])
    assert outputs['gt'].shape == (960, 1280, 3)

0 comments on commit a0eaf22

Please sign in to comment.