From a0eaf2264f23825d4d002844998f8fd3a55ce808 Mon Sep 17 00:00:00 2001 From: ys-li <56712176+Yshuo-Li@users.noreply.github.com> Date: Tue, 18 May 2021 15:11:05 +0800 Subject: [PATCH] [Fix] Merge resize and sr_resize. (#310) * [Fix] Merge resize and sr_resize. * Fix * Fix Co-authored-by: liyinshuo --- mmedit/datasets/pipelines/__init__.py | 4 +- mmedit/datasets/pipelines/augmentation.py | 36 +++++++--- .../{sr_resize.py => random_down_sampling.py} | 70 ------------------- tests/test_augmentation.py | 17 +++-- ...resize.py => test_random_down_sampling.py} | 23 +----- 5 files changed, 42 insertions(+), 108 deletions(-) rename mmedit/datasets/pipelines/{sr_resize.py => random_down_sampling.py} (64%) rename tests/{test_sr_resize.py => test_random_down_sampling.py} (59%) diff --git a/mmedit/datasets/pipelines/__init__.py b/mmedit/datasets/pipelines/__init__.py index 4868fa0c87..a135734401 100644 --- a/mmedit/datasets/pipelines/__init__.py +++ b/mmedit/datasets/pipelines/__init__.py @@ -16,7 +16,7 @@ GenerateTrimap, GenerateTrimapWithDistTransform, MergeFgAndBg, PerturbBg, TransformTrimap) from .normalization import Normalize, RescaleToZeroOne -from .sr_resize import RandomDownSampling, SRResize +from .random_down_sampling import RandomDownSampling __all__ = [ 'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask', @@ -29,7 +29,7 @@ 'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop', 'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg', 'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling', - 'GenerateTrimapWithDistTransform', 'TransformTrimap', 'SRResize', + 'GenerateTrimapWithDistTransform', 'TransformTrimap', 'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence', 'CropLike' ] diff --git a/mmedit/datasets/pipelines/augmentation.py b/mmedit/datasets/pipelines/augmentation.py index a5bbf83fa8..0071295de7 100644 --- a/mmedit/datasets/pipelines/augmentation.py +++ b/mmedit/datasets/pipelines/augmentation.py @@ -33,7 +33,7 @@ class Resize: keys (list[str]): The images to be resized. scale (float | Tuple[int]): If scale is Tuple(int), target spatial size (h, w). Otherwise, target spatial size is scaled by input - size. If any of scale is -1, we will rescale short edge. + size. Note that when it is used, `size_factor` and `max_size` are useless. Default: None keep_ratio (bool): If set to True, images will be resized without @@ -50,6 +50,12 @@ class Resize: interpolation (str): Algorithm used for interpolation: "nearest" | "bilinear" | "bicubic" | "area" | "lanczos". Default: "bilinear". + backend (str | None): The image resize backend type. Options are `cv2`, + `pillow`, `None`. If backend is None, the global imread_backend + specified by ``mmcv.use_backend()`` will be used. + Default: None. + output_keys (list[str] | None): The resized images. Default: None + Note that if it is not `None`, its length shuld be equal to keys. """ def __init__(self, @@ -58,8 +64,14 @@ def __init__(self, keep_ratio=False, size_factor=None, max_size=None, - interpolation='bilinear'): + interpolation='bilinear', + backend=None, + output_keys=None): assert keys, 'Keys should not be empty.' + if output_keys: + assert len(output_keys) == len(keys) + else: + output_keys = keys if size_factor: assert scale is None, ('When size_factor is used, scale should ', f'be None. But received {scale}.') @@ -83,11 +95,13 @@ def __init__(self, f'Scale must be None, float or tuple of int, but got ' f'{type(scale)}.') self.keys = keys + self.output_keys = output_keys self.scale = scale self.size_factor = size_factor self.max_size = max_size self.keep_ratio = keep_ratio self.interpolation = interpolation + self.backend = backend def _resize(self, img): if self.keep_ratio: @@ -95,13 +109,15 @@ def _resize(self, img): img, self.scale, return_scale=True, - interpolation=self.interpolation) + interpolation=self.interpolation, + backend=self.backend) else: img, w_scale, h_scale = mmcv.imresize( img, self.scale, return_scale=True, - interpolation=self.interpolation) + interpolation=self.interpolation, + backend=self.backend) self.scale_factor = np.array((w_scale, h_scale), dtype=np.float32) return img @@ -125,21 +141,23 @@ def __call__(self, results): new_w = min(self.max_size - (self.max_size % self.size_factor), new_w) self.scale = (new_w, new_h) - for key in self.keys: - results[key] = self._resize(results[key]) - if len(results[key].shape) == 2: - results[key] = np.expand_dims(results[key], axis=2) + for key, out_key in zip(self.keys, self.output_keys): + results[out_key] = self._resize(results[key]) + if len(results[out_key].shape) == 2: + results[out_key] = np.expand_dims(results[out_key], axis=2) results['scale_factor'] = self.scale_factor results['keep_ratio'] = self.keep_ratio results['interpolation'] = self.interpolation + results['backend'] = self.backend return results def __repr__(self): repr_str = self.__class__.__name__ repr_str += ( - f'(keys={self.keys}, scale={self.scale}, ' + f'(keys={self.keys}, output_keys={self.output_keys}, ' + f'scale={self.scale}, ' f'keep_ratio={self.keep_ratio}, size_factor={self.size_factor}, ' f'max_size={self.max_size},interpolation={self.interpolation})') return repr_str diff --git a/mmedit/datasets/pipelines/sr_resize.py b/mmedit/datasets/pipelines/random_down_sampling.py similarity index 64% rename from mmedit/datasets/pipelines/sr_resize.py rename to mmedit/datasets/pipelines/random_down_sampling.py index 5143dc04ea..f1187f7f76 100644 --- a/mmedit/datasets/pipelines/sr_resize.py +++ b/mmedit/datasets/pipelines/random_down_sampling.py @@ -7,76 +7,6 @@ from ..registry import PIPELINES -@PIPELINES.register_module() -class SRResize: - """Resize image by a scale, including upsampling and downsampling. - - Image will be loaded from the input_key and the result will be saved - in the specified output_key (can equal to input_key). - - Args: - scale (float): The resampling scale. scale > 0. - scale > 1: upsampling. - scale < 1: downsampling. - input_key (str): The input key. - output_key (str): The output key. - interpolation (str): Interpolation method, accepted values are - "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' - backend, "nearest", "bilinear", "bicubic", "box", "lanczos", - "hamming" for 'pillow' backend. - Default: "bicubic". - backend (str | None): The image resize backend type. Options are `cv2`, - `pillow`, `None`. If backend is None, the global imread_backend - specified by ``mmcv.use_backend()`` will be used. - Default: "pillow". - """ - - def __init__(self, - scale, - input_key, - output_key, - interpolation='bicubic', - backend='pillow'): - self.scale = scale - self.input_key = input_key - self.output_key = output_key - self.interpolation = interpolation - self.backend = backend - - def __call__(self, results): - """Call function. - - Args: - results (dict): A dict containing the necessary information and - data for augmentation. self.input_key is required. - - Returns: - dict: A dict containing the processed data and information. - supplement self.output_key to keys. - """ - assert self.input_key in results, f'Cannot find {self.input_key}.' - image_in = results[self.input_key] - h_in, w_in = image_in.shape[:2] - h_out = math.floor(h_in * self.scale + 1e-9) - w_out = math.floor(w_in * self.scale + 1e-9) - image_out = resize_fn(image_in, (w_out, h_out), self.interpolation, - self.backend) - - results[self.output_key] = image_out - - return results - - def __repr__(self): - repr_str = self.__class__.__name__ - repr_str += (f' scale={self.scale}, ' - f'input_key={self.input_key}, ' - f'output_key={self.output_key}, ' - f'interpolation={self.interpolation}, ' - f'backend={self.backend}') - - return repr_str - - @PIPELINES.register_module() class RandomDownSampling: """Generate LQ image from GT (and crop), which will randomly pick a scale. diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 18113e9b9d..9affd937c7 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -461,14 +461,21 @@ def test_resize(self): assert results['gt_img'].shape[:2] == (128, 128) # test input with shape (256, 256) - results = dict(gt_img=self.results['img'][..., 0].copy()) - resize = Resize(['gt_img'], scale=(128, 128), keep_ratio=False) + results = dict(gt_img=self.results['img'][..., 0].copy(), alpha=alpha) + resize = Resize(['gt_img', 'alpha'], + scale=(128, 128), + keep_ratio=False, + output_keys=['lq_img', 'beta']) results = resize(results) - assert results['gt_img'].shape == (128, 128, 1) + assert results['gt_img'].shape == (256, 256) + assert results['lq_img'].shape == (128, 128, 1) + assert results['alpha'].shape == (240, 320) + assert results['beta'].shape == (128, 128, 1) name_ = str(resize_keep_ratio) assert name_ == resize_keep_ratio.__class__.__name__ + ( - f"(keys={['gt_img']}, scale=(128, 128), " + "(keys=['gt_img'], output_keys=['gt_img'], " + 'scale=(128, 128), ' f'keep_ratio={False}, size_factor=None, ' 'max_size=None,interpolation=bilinear)') @@ -673,7 +680,7 @@ def mirror_sequence(self): results['gt'][-i - 1]) assert repr(mirror_sequence) == mirror_sequence.__class__.__name__ + ( - f"(keys=['lq', 'gt'])") + "(keys=['lq', 'gt'])") # each key should contain a list of nparray with pytest.raises(TypeError): diff --git a/tests/test_sr_resize.py b/tests/test_random_down_sampling.py similarity index 59% rename from tests/test_sr_resize.py rename to tests/test_random_down_sampling.py index c8ef599bff..7b6749026d 100644 --- a/tests/test_sr_resize.py +++ b/tests/test_random_down_sampling.py @@ -1,6 +1,6 @@ import numpy as np -from mmedit.datasets.pipelines import RandomDownSampling, SRResize +from mmedit.datasets.pipelines import RandomDownSampling def test_random_down_sampling(): @@ -31,24 +31,3 @@ def test_random_down_sampling(): f'patch_size={down_sampling2.patch_size}, ' + f'interpolation={down_sampling2.interpolation}, ' + f'backend={down_sampling2.backend}') - - -def test_sr_resize(): - img = np.uint8(np.random.randn(480, 640, 3) * 255) - inputs = dict(gt=img) - re_size = SRResize(scale=1 / 4, input_key='gt', output_key='lq') - results = re_size(inputs) - assert set(list(results.keys())) == set(['gt', 'lq']) - assert results['lq'].shape == (120, 160, 3) - assert repr(re_size) == ( - re_size.__class__.__name__ + f' scale={re_size.scale}, ' + - f'input_key={re_size.input_key}, ' + - f'output_key={re_size.output_key}, ' + - f'interpolation={re_size.interpolation}, ' + - f'backend={re_size.backend}') - - inputs = dict(gt=img) - re_size = SRResize(scale=2, input_key='gt', output_key='gt') - results = re_size(inputs) - assert set(list(results.keys())) == set(['gt']) - assert results['gt'].shape == (960, 1280, 3)