diff --git a/mmedit/datasets/pipelines/__init__.py b/mmedit/datasets/pipelines/__init__.py
index ac326c5de3..4868fa0c87 100644
--- a/mmedit/datasets/pipelines/__init__.py
+++ b/mmedit/datasets/pipelines/__init__.py
@@ -6,7 +6,6 @@
 from .compose import Compose
 from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
                    CropLike, FixedCrop, ModCrop, PairedRandomCrop)
-from .down_sampling import RandomDownSampling
 from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
                         ToTensor)
 from .generate_coordinate_and_cell import GenerateCoordinateAndCell
@@ -17,6 +16,7 @@
                           GenerateTrimap, GenerateTrimapWithDistTransform,
                           MergeFgAndBg, PerturbBg, TransformTrimap)
 from .normalization import Normalize, RescaleToZeroOne
+from .sr_resize import RandomDownSampling, SRResize
 
 __all__ = [
     'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask',
@@ -29,7 +29,7 @@
     'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
     'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
     'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
-    'GenerateTrimapWithDistTransform', 'TransformTrimap',
+    'GenerateTrimapWithDistTransform', 'TransformTrimap', 'SRResize',
     'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
     'CropLike'
 ]
diff --git a/mmedit/datasets/pipelines/down_sampling.py b/mmedit/datasets/pipelines/sr_resize.py
similarity index 61%
rename from mmedit/datasets/pipelines/down_sampling.py
rename to mmedit/datasets/pipelines/sr_resize.py
index 346a7ff024..5143dc04ea 100644
--- a/mmedit/datasets/pipelines/down_sampling.py
+++ b/mmedit/datasets/pipelines/sr_resize.py
@@ -7,6 +7,76 @@
 from ..registry import PIPELINES
 
 
+@PIPELINES.register_module()
+class SRResize:
+    """Resize image by a scale, including upsampling and downsampling.
+
+    Image will be loaded from the input_key and the result will be saved
+    in the specified output_key (can equal to input_key).
+
+    Args:
+        scale (float): The resampling scale. scale > 0.
+            scale > 1: upsampling.
+            scale < 1: downsampling.
+        input_key (str): The input key.
+        output_key (str): The output key.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
+            "hamming" for 'pillow' backend.
+            Default: "bicubic".
+        backend (str | None): The image resize backend type. Options are `cv2`,
+            `pillow`, `None`. If backend is None, the global imread_backend
+            specified by ``mmcv.use_backend()`` will be used.
+            Default: "pillow".
+    """
+
+    def __init__(self,
+                 scale,
+                 input_key,
+                 output_key,
+                 interpolation='bicubic',
+                 backend='pillow'):
+        self.scale = scale
+        self.input_key = input_key
+        self.output_key = output_key
+        self.interpolation = interpolation
+        self.backend = backend
+
+    def __call__(self, results):
+        """Call function.
+
+        Args:
+            results (dict): A dict containing the necessary information and
+                data for augmentation. self.input_key is required.
+
+        Returns:
+            dict: A dict containing the processed data and information.
+                supplement self.output_key to keys.
+        """
+        assert self.input_key in results, f'Cannot find {self.input_key}.'
+        image_in = results[self.input_key]
+        h_in, w_in = image_in.shape[:2]
+        h_out = math.floor(h_in * self.scale + 1e-9)
+        w_out = math.floor(w_in * self.scale + 1e-9)
+        image_out = resize_fn(image_in, (w_out, h_out), self.interpolation,
+                              self.backend)
+
+        results[self.output_key] = image_out
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += (f' scale={self.scale}, '
+                     f'input_key={self.input_key}, '
+                     f'output_key={self.output_key}, '
+                     f'interpolation={self.interpolation}, '
+                     f'backend={self.backend}')
+
+        return repr_str
+
+
 @PIPELINES.register_module()
 class RandomDownSampling:
     """Generate LQ image from GT (and crop), which will randomly pick a scale.
@@ -81,9 +151,11 @@ def __call__(self, results):
 
     def __repr__(self):
         repr_str = self.__class__.__name__
-        repr_str += (f'scale_min={self.scale_min}, '
+        repr_str += (f' scale_min={self.scale_min}, '
                      f'scale_max={self.scale_max}, '
-                     f'patch_size={self.patch_size}')
+                     f'patch_size={self.patch_size}, '
+                     f'interpolation={self.interpolation}, '
+                     f'backend={self.backend}')
 
         return repr_str
 
diff --git a/tests/test_down_sampling.py b/tests/test_down_sampling.py
deleted file mode 100644
index d761f9aaa2..0000000000
--- a/tests/test_down_sampling.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import numpy as np
-
-from mmedit.datasets.pipelines import RandomDownSampling
-
-
-def test_down_sampling():
-    img1 = np.uint8(np.random.randn(480, 640, 3) * 255)
-    inputs1 = dict(gt=img1)
-    down_sampling1 = RandomDownSampling(
-        scale_min=1, scale_max=4, patch_size=None)
-    results1 = down_sampling1(inputs1)
-    assert set(list(results1.keys())) == set(['gt', 'lq', 'scale'])
-    assert repr(down_sampling1) == (
-        down_sampling1.__class__.__name__ +
-        f'scale_min={down_sampling1.scale_min}, ' +
-        f'scale_max={down_sampling1.scale_max}, ' +
-        f'patch_size={down_sampling1.patch_size}')
-
-    img2 = np.uint8(np.random.randn(480, 640, 3) * 255)
-    inputs2 = dict(gt=img2)
-    down_sampling2 = RandomDownSampling(
-        scale_min=1, scale_max=4, patch_size=48)
-    results2 = down_sampling2(inputs2)
-    assert set(list(results2.keys())) == set(['gt', 'lq', 'scale'])
-    assert repr(down_sampling2) == (
-        down_sampling2.__class__.__name__ +
-        f'scale_min={down_sampling2.scale_min}, ' +
-        f'scale_max={down_sampling2.scale_max}, ' +
-        f'patch_size={down_sampling2.patch_size}')
diff --git a/tests/test_sr_resize.py b/tests/test_sr_resize.py
new file mode 100644
index 0000000000..c8ef599bff
--- /dev/null
+++ b/tests/test_sr_resize.py
@@ -0,0 +1,54 @@
+import numpy as np
+
+from mmedit.datasets.pipelines import RandomDownSampling, SRResize
+
+
+def test_random_down_sampling():
+    img1 = np.uint8(np.random.randn(480, 640, 3) * 255)
+    inputs1 = dict(gt=img1)
+    down_sampling1 = RandomDownSampling(
+        scale_min=1, scale_max=4, patch_size=None)
+    results1 = down_sampling1(inputs1)
+    assert set(list(results1.keys())) == set(['gt', 'lq', 'scale'])
+    assert repr(down_sampling1) == (
+        down_sampling1.__class__.__name__ +
+        f' scale_min={down_sampling1.scale_min}, ' +
+        f'scale_max={down_sampling1.scale_max}, ' +
+        f'patch_size={down_sampling1.patch_size}, ' +
+        f'interpolation={down_sampling1.interpolation}, ' +
+        f'backend={down_sampling1.backend}')
+
+    img2 = np.uint8(np.random.randn(480, 640, 3) * 255)
+    inputs2 = dict(gt=img2)
+    down_sampling2 = RandomDownSampling(
+        scale_min=1, scale_max=4, patch_size=48)
+    results2 = down_sampling2(inputs2)
+    assert set(list(results2.keys())) == set(['gt', 'lq', 'scale'])
+    assert repr(down_sampling2) == (
+        down_sampling2.__class__.__name__ +
+        f' scale_min={down_sampling2.scale_min}, ' +
+        f'scale_max={down_sampling2.scale_max}, ' +
+        f'patch_size={down_sampling2.patch_size}, ' +
+        f'interpolation={down_sampling2.interpolation}, ' +
+        f'backend={down_sampling2.backend}')
+
+
+def test_sr_resize():
+    img = np.uint8(np.random.randn(480, 640, 3) * 255)
+    inputs = dict(gt=img)
+    re_size = SRResize(scale=1 / 4, input_key='gt', output_key='lq')
+    results = re_size(inputs)
+    assert set(list(results.keys())) == set(['gt', 'lq'])
+    assert results['lq'].shape == (120, 160, 3)
+    assert repr(re_size) == (
+        re_size.__class__.__name__ + f' scale={re_size.scale}, ' +
+        f'input_key={re_size.input_key}, ' +
+        f'output_key={re_size.output_key}, ' +
+        f'interpolation={re_size.interpolation}, ' +
+        f'backend={re_size.backend}')
+
+    inputs = dict(gt=img)
+    re_size = SRResize(scale=2, input_key='gt', output_key='gt')
+    results = re_size(inputs)
+    assert set(list(results.keys())) == set(['gt'])
+    assert results['gt'].shape == (960, 1280, 3)
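
Usage sketch (illustrative, not part of the patch): since SRResize is registered in PIPELINES, it can be referenced by type name inside an MMEditing dataset pipeline config once this change lands. The snippet below is a minimal sketch of that wiring; the neighbouring transforms (LoadImageFromFile, RescaleToZeroOne, ImageToTensor, Collect) are existing pipeline members exported from __init__.py above, but the particular keys and the 1/4 scale are assumptions chosen for illustration, not values taken from this diff.

    # Hypothetical pipeline config sketch; keys and scale are illustrative only.
    train_pipeline = [
        dict(type='LoadImageFromFile', io_backend='disk', key='gt'),
        # scale < 1 downsamples: derive the LQ image from the loaded GT.
        dict(
            type='SRResize',
            scale=1 / 4,
            input_key='gt',
            output_key='lq',
            interpolation='bicubic',
            backend='pillow'),
        dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
        dict(type='ImageToTensor', keys=['lq', 'gt']),
        dict(type='Collect', keys=['lq', 'gt'], meta_keys=[]),
    ]

Because output_key may equal input_key, the same transform can also rescale a key in place, which is what the second test case exercises with scale=2 on 'gt'.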