Skip to content

Commit

Permalink
Add SRResize (#295)
Browse files Browse the repository at this point in the history
* Add ReSampling

* Add ReSampling

* Fix docstring

* fix

* add generate-by-resize

* Rename

* update docstring

Co-authored-by: liyinshuo <[email protected]>
  • Loading branch information
Yshuo-Li and liyinshuo authored May 12, 2021
1 parent 613b260 commit d2a4142
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 33 deletions.
4 changes: 2 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
CropLike, FixedCrop, ModCrop, PairedRandomCrop)
from .down_sampling import RandomDownSampling
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
from .generate_coordinate_and_cell import GenerateCoordinateAndCell
Expand All @@ -17,6 +16,7 @@
GenerateTrimap, GenerateTrimapWithDistTransform,
MergeFgAndBg, PerturbBg, TransformTrimap)
from .normalization import Normalize, RescaleToZeroOne
from .sr_resize import RandomDownSampling, SRResize

__all__ = [
'Collect', 'FormatTrimap', 'LoadImageFromFile', 'LoadMask',
Expand All @@ -29,7 +29,7 @@
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap',
'GenerateTrimapWithDistTransform', 'TransformTrimap', 'SRResize',
'GenerateCoordinateAndCell', 'GenerateSegmentIndices', 'MirrorSequence',
'CropLike'
]
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,76 @@
from ..registry import PIPELINES


@PIPELINES.register_module()
class SRResize:
"""Resize image by a scale, including upsampling and downsampling.
Image will be loaded from the input_key and the result will be saved
in the specified output_key (can equal to input_key).
Args:
scale (float): The resampling scale. scale > 0.
scale > 1: upsampling.
scale < 1: downsampling.
input_key (str): The input key.
output_key (str): The output key.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
"hamming" for 'pillow' backend.
Default: "bicubic".
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used.
Default: "pillow".
"""

def __init__(self,
scale,
input_key,
output_key,
interpolation='bicubic',
backend='pillow'):
self.scale = scale
self.input_key = input_key
self.output_key = output_key
self.interpolation = interpolation
self.backend = backend

def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation. self.input_key is required.
Returns:
dict: A dict containing the processed data and information.
supplement self.output_key to keys.
"""
assert self.input_key in results, f'Cannot find {self.input_key}.'
image_in = results[self.input_key]
h_in, w_in = image_in.shape[:2]
h_out = math.floor(h_in * self.scale + 1e-9)
w_out = math.floor(w_in * self.scale + 1e-9)
image_out = resize_fn(image_in, (w_out, h_out), self.interpolation,
self.backend)

results[self.output_key] = image_out

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f' scale={self.scale}, '
f'input_key={self.input_key}, '
f'output_key={self.output_key}, '
f'interpolation={self.interpolation}, '
f'backend={self.backend}')

return repr_str


@PIPELINES.register_module()
class RandomDownSampling:
"""Generate LQ image from GT (and crop), which will randomly pick a scale.
Expand Down Expand Up @@ -81,9 +151,11 @@ def __call__(self, results):

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'scale_min={self.scale_min}, '
repr_str += (f' scale_min={self.scale_min}, '
f'scale_max={self.scale_max}, '
f'patch_size={self.patch_size}')
f'patch_size={self.patch_size}, '
f'interpolation={self.interpolation}, '
f'backend={self.backend}')

return repr_str

Expand Down
29 changes: 0 additions & 29 deletions tests/test_down_sampling.py

This file was deleted.

54 changes: 54 additions & 0 deletions tests/test_sr_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np

from mmedit.datasets.pipelines import RandomDownSampling, SRResize


def test_random_down_sampling():
img1 = np.uint8(np.random.randn(480, 640, 3) * 255)
inputs1 = dict(gt=img1)
down_sampling1 = RandomDownSampling(
scale_min=1, scale_max=4, patch_size=None)
results1 = down_sampling1(inputs1)
assert set(list(results1.keys())) == set(['gt', 'lq', 'scale'])
assert repr(down_sampling1) == (
down_sampling1.__class__.__name__ +
f' scale_min={down_sampling1.scale_min}, ' +
f'scale_max={down_sampling1.scale_max}, ' +
f'patch_size={down_sampling1.patch_size}, ' +
f'interpolation={down_sampling1.interpolation}, ' +
f'backend={down_sampling1.backend}')

img2 = np.uint8(np.random.randn(480, 640, 3) * 255)
inputs2 = dict(gt=img2)
down_sampling2 = RandomDownSampling(
scale_min=1, scale_max=4, patch_size=48)
results2 = down_sampling2(inputs2)
assert set(list(results2.keys())) == set(['gt', 'lq', 'scale'])
assert repr(down_sampling2) == (
down_sampling2.__class__.__name__ +
f' scale_min={down_sampling2.scale_min}, ' +
f'scale_max={down_sampling2.scale_max}, ' +
f'patch_size={down_sampling2.patch_size}, ' +
f'interpolation={down_sampling2.interpolation}, ' +
f'backend={down_sampling2.backend}')


def test_sr_resize():
img = np.uint8(np.random.randn(480, 640, 3) * 255)
inputs = dict(gt=img)
re_size = SRResize(scale=1 / 4, input_key='gt', output_key='lq')
results = re_size(inputs)
assert set(list(results.keys())) == set(['gt', 'lq'])
assert results['lq'].shape == (120, 160, 3)
assert repr(re_size) == (
re_size.__class__.__name__ + f' scale={re_size.scale}, ' +
f'input_key={re_size.input_key}, ' +
f'output_key={re_size.output_key}, ' +
f'interpolation={re_size.interpolation}, ' +
f'backend={re_size.backend}')

inputs = dict(gt=img)
re_size = SRResize(scale=2, input_key='gt', output_key='gt')
results = re_size(inputs)
assert set(list(results.keys())) == set(['gt'])
assert results['gt'].shape == (960, 1280, 3)

0 comments on commit d2a4142

Please sign in to comment.