Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add down_sampling.py for generating LQ image from GT image. #222

Merged
merged 5 commits into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmedit/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
FixedCrop, ModCrop, PairedRandomCrop)
from .down_sampling import DownSampling
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
Expand All @@ -25,6 +26,6 @@
'MergeFgAndBg', 'CompositeFg', 'TemporalReverse', 'LoadImageFromFileList',
'GenerateFrameIndices', 'GenerateFrameIndiceswithPadding', 'FixedCrop',
'LoadPairedImageFromFile', 'GenerateSoftSeg', 'GenerateSeg', 'PerturbBg',
'CropAroundFg', 'GetSpatialDiscountMask',
'CropAroundFg', 'GetSpatialDiscountMask', 'DownSampling',
'GenerateTrimapWithDistTransform', 'TransformTrimap'
]
73 changes: 73 additions & 0 deletions mmedit/datasets/pipelines/down_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import math
import random

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from ..registry import PIPELINES


@PIPELINES.register_module()
class DownSampling:
    """Generate an LQ image from a GT image by random-scale down-sampling.

    A scale is drawn uniformly from ``[scale_min, scale_max]`` (inclusive at
    both ends). If ``inp_size`` is ``None``, the whole GT image is trimmed to
    an exact multiple of the LQ size and down-sampled. Otherwise a random
    ``inp_size * scale`` square patch is cropped from GT and down-sampled to
    an ``inp_size`` square LQ patch.

    Required keys in ``results``: "gt".
    Modified keys: "gt".
    Added keys: "lq", "scale".

    Args:
        scale_min (float): Minimum of the down-sampling scale (inclusive).
            Default: 1.
        scale_max (float): Maximum of the down-sampling scale (inclusive).
            Default: 4.
        inp_size (int | None): Side length of the cropped LQ patch.
            Default: None, meaning no crop.
    """

    def __init__(self, scale_min=1, scale_max=4, inp_size=None):
        assert scale_max >= scale_min, (
            f'scale_max ({scale_max}) must be >= scale_min ({scale_min})')
        self.scale_min = scale_min
        self.scale_max = scale_max
        self.inp_size = inp_size

    def __call__(self, results):
        """Call function.

        Args:
            results (dict): A dict containing the necessary information and
                data for augmentation. Must contain key "gt" holding an
                H x W x C image. "gt" is modified in place; "lq" and "scale"
                are added.

        Returns:
            dict: A dict containing the processed data and information.

        Raises:
            ValueError: If ``inp_size`` is set and the GT image is smaller
                than the required HR patch (``round(inp_size * scale)``).
        """
        img = results['gt']
        scale = np.random.uniform(self.scale_min, self.scale_max)
        if self.inp_size is None:
            # Down-sample the full image. Trim GT so its size is an exact
            # (rounded) multiple of the LQ size; 1e-9 guards against float
            # representation error in the division.
            h_lr = math.floor(img.shape[-3] / scale + 1e-9)
            w_lr = math.floor(img.shape[-2] / scale + 1e-9)
            img = img[:round(h_lr * scale), :round(w_lr * scale), :]
            img_down = resize_fn(img, (w_lr, h_lr))
            crop_lr, crop_hr = img_down, img
        else:
            w_lr = self.inp_size
            w_hr = round(w_lr * scale)
            if img.shape[-3] < w_hr or img.shape[-2] < w_hr:
                raise ValueError(
                    f'GT image ({img.shape[-3]}x{img.shape[-2]}) is smaller '
                    f'than the required HR patch size {w_hr}.')
            # Random top-left corner of the HR crop; upper bound is
            # inclusive, hence the +1 for np.random.randint.
            x0 = np.random.randint(0, img.shape[-3] - w_hr + 1)
            y0 = np.random.randint(0, img.shape[-2] - w_hr + 1)
            crop_hr = img[x0:x0 + w_hr, y0:y0 + w_hr, :]
            crop_lr = resize_fn(crop_hr, w_lr)
        results['gt'] = crop_hr
        results['lq'] = crop_lr
        results['scale'] = scale

        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}(scale_min={self.scale_min}, '
                f'scale_max={self.scale_max}, inp_size={self.inp_size})')


def resize_fn(img, size):
    """Resize the given image to ``size`` with bicubic interpolation.

    Args:
        img (np.ndarray | torch.Tensor): Image to be resized. An
            ``np.ndarray`` is expected in H x W x C layout; a
            ``torch.Tensor`` in C x H x W layout (as required by
            ``transforms.ToPILImage``).
        size (int | tuple[int, int]): Target size as ``(width, height)``,
            following the PIL convention. A single int yields a square
            output.

    Returns:
        np.ndarray | torch.Tensor: Resized image, same type as ``img``.

    Raises:
        TypeError: If ``img`` is neither ``np.ndarray`` nor
            ``torch.Tensor``.
    """
    if isinstance(size, int):
        size = (size, size)
    if isinstance(img, np.ndarray):
        # PIL's Image.resize takes (width, height).
        return np.asarray(Image.fromarray(img).resize(size, Image.BICUBIC))
    elif isinstance(img, torch.Tensor):
        # torchvision's transforms.Resize takes a (height, width) sequence,
        # so swap the (width, height) pair before passing it on.
        pil_img = transforms.ToPILImage()(img)
        resized = transforms.Resize((size[1], size[0]),
                                    Image.BICUBIC)(pil_img)
        return transforms.ToTensor()(resized)
    else:
        raise TypeError('img should be np.ndarray or torch.Tensor, '
                        f'but got {type(img)}')