Skip to content

Commit

Permalink
Add down_sampling.py for generating LQ image from GT image. (#222)
Browse files Browse the repository at this point in the history
* Add down_sampling.py for generating LQ image from GT image, which is required in LIIF.

* Add '__repr__' and test_down_sampling.py.

* Add docstring, rename parameter and change the function of resize.

* Fine-tuning code and docstring of RandomDownSampling class.

* Remove hardcode of bicubic and pillow.

Co-authored-by: 李尹硕 <SENSETIME\[email protected]>
  • Loading branch information
Yshuo-Li and 李尹硕 authored Mar 15, 2021
1 parent 79f74a1 commit ad517d4
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
FixedCrop, ModCrop, PairedRandomCrop)
from .down_sampling import RandomDownSampling
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
Expand All @@ -25,6 +26,6 @@
'MergeFgAndBg', 'CompositeFg', 'TemporalReverse', 'LoadImageFromFileList',
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask',
'CropAroundFg', 'GetSpatialDiscountMask', 'RandomDownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap'
]
122 changes: 122 additions & 0 deletions mmedit/datasets/pipelines/down_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import math

import numpy as np
import torch
from mmcv import imresize

from ..registry import PIPELINES


@PIPELINES.register_module()
class RandomDownSampling:
"""Generate LQ image from GT (and crop), which will randomly pick a scale.
Args:
scale_min (float): The minimum of upsampling scale, inclusive.
Default: 1.0.
scale_max (float): The maximum of upsampling scale, exclusive.
Default: 4.0.
patch_size (int): The cropped lr patch size.
Default: None, means no crop.
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
"hamming" for 'pillow' backend.
Default: "bicubic".
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used.
Default: "pillow".
Scale will be picked in the range of [scale_min, scale_max).
"""

def __init__(self,
scale_min=1.0,
scale_max=4.0,
patch_size=None,
interpolation='bicubic',
backend='pillow'):
assert scale_max >= scale_min
self.scale_min = scale_min
self.scale_max = scale_max
self.patch_size = patch_size
self.interpolation = interpolation
self.backend = backend

def __call__(self, results):
"""Call function.
Args:
results (dict): A dict containing the necessary information and
data for augmentation. 'gt' is required.
Returns:
dict: A dict containing the processed data and information.
modified 'gt', supplement 'lq' and 'scale' to keys.
"""
img = results['gt']
scale = np.random.uniform(self.scale_min, self.scale_max)

if self.patch_size is None:
h_lr = math.floor(img.shape[-3] / scale + 1e-9)
w_lr = math.floor(img.shape[-2] / scale + 1e-9)
img = img[:round(h_lr * scale), :round(w_lr * scale), :]
img_down = resize_fn(img, (w_lr, h_lr), self.interpolation,
self.backend)
crop_lr, crop_hr = img_down, img
else:
w_lr = self.patch_size
w_hr = round(w_lr * scale)
x0 = np.random.randint(0, img.shape[-3] - w_hr)
y0 = np.random.randint(0, img.shape[-2] - w_hr)
crop_hr = img[x0:x0 + w_hr, y0:y0 + w_hr, :]
crop_lr = resize_fn(crop_hr, w_lr, self.interpolation,
self.backend)
results['gt'] = crop_hr
results['lq'] = crop_lr
results['scale'] = scale

return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'scale_min={self.scale_min}, '
f'scale_max={self.scale_max}, '
f'patch_size={self.patch_size}')

return repr_str


def resize_fn(img, size, interpolation='bicubic', backend='pillow'):
"""Resize the given image to a given size.
Args:
img (ndarray | torch.Tensor): The input image.
size (int | tuple[int]): Target size w or (w, h).
interpolation (str): Interpolation method, accepted values are
"nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
backend, "nearest", "bilinear", "bicubic", "box", "lanczos",
"hamming" for 'pillow' backend.
Default: "bicubic".
backend (str | None): The image resize backend type. Options are `cv2`,
`pillow`, `None`. If backend is None, the global imread_backend
specified by ``mmcv.use_backend()`` will be used.
Default: "pillow".
Returns:
ndarray | torch.Tensor: `resized_img`, whose type is same as `img`.
"""
if isinstance(size, int):
size = (size, size)
if isinstance(img, np.ndarray):
return imresize(
img, size, interpolation=interpolation, backend=backend)
elif isinstance(img, torch.Tensor):
image = imresize(
img.numpy(), size, interpolation=interpolation, backend=backend)
return torch.from_numpy(image)

else:
raise TypeError('img should got np.ndarray or torch.Tensor,'
f'but got {type(img)}')
29 changes: 29 additions & 0 deletions tests/test_down_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import numpy as np

from mmedit.datasets.pipelines import RandomDownSampling


def test_down_sampling():
img1 = np.uint8(np.random.randn(480, 640, 3) * 255)
inputs1 = dict(gt=img1)
down_sampling1 = RandomDownSampling(
scale_min=1, scale_max=4, patch_size=None)
results1 = down_sampling1(inputs1)
assert set(list(results1.keys())) == set(['gt', 'lq', 'scale'])
assert repr(down_sampling1) == (
down_sampling1.__class__.__name__ +
f'scale_min={down_sampling1.scale_min}, ' +
f'scale_max={down_sampling1.scale_max}, ' +
f'patch_size={down_sampling1.patch_size}')

img2 = np.uint8(np.random.randn(480, 640, 3) * 255)
inputs2 = dict(gt=img2)
down_sampling2 = RandomDownSampling(
scale_min=1, scale_max=4, patch_size=48)
results2 = down_sampling2(inputs2)
assert set(list(results2.keys())) == set(['gt', 'lq', 'scale'])
assert repr(down_sampling2) == (
down_sampling2.__class__.__name__ +
f'scale_min={down_sampling2.scale_min}, ' +
f'scale_max={down_sampling2.scale_max}, ' +
f'patch_size={down_sampling2.patch_size}')

0 comments on commit ad517d4

Please sign in to comment.