Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] support all image super resolution models inference #1662

Merged
merged 6 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions configs/dic/dic_x8c48b6_4xb2-150k_celeba-hq.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,30 @@
]
test_pipeline = valid_pipeline

# Pipeline used when the inferencer is run directly on a user-supplied image:
# load as RGB, resize to a fixed (128, 128) size, then downscale by 8x in
# place — presumably to synthesize the low-resolution DIC input (x8 model);
# confirm against the model's expected input size.
# NOTE: names prefixed with '_' are excluded from the parsed mmengine config,
# so only `inference_pipeline` itself becomes a config entry.
_load_step = dict(
    type='LoadImageFromFile',
    key='img',
    color_type='color',
    channel_order='rgb',
    imdecode_backend='cv2')

# Bring the loaded image to the fixed spatial size (width, height).
_fixed_size_step = dict(
    type='Resize',
    scale=(128, 128),
    keys=['img'],
    interpolation='bicubic',
    backend='pillow')

# Downscale by a factor of 8 (aspect ratio kept), writing back to 'img'.
_downscale_step = dict(
    type='Resize',
    scale=1 / 8,
    keep_ratio=True,
    keys=['img'],
    output_keys=['img'],
    interpolation='bicubic',
    backend='pillow')

inference_pipeline = [
    _load_step,
    _fixed_size_step,
    _downscale_step,
    dict(type='PackEditInputs'),
]

# dataset settings
dataset_type = 'BasicImageDataset'
data_root = 'data'
Expand Down
15 changes: 15 additions & 0 deletions configs/glean/glean_x8_2xb8_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,21 @@
dict(type='PackEditInputs')
]

# Pipeline used when the inferencer is run directly on a user-supplied image:
# load as RGB and resize to a fixed (32, 32) size — presumably the
# low-resolution input size of the x8 GLEAN model; confirm against the
# model definition.
# NOTE: names prefixed with '_' are excluded from the parsed mmengine config,
# so only `inference_pipeline` itself becomes a config entry.
_load_step = dict(
    type='LoadImageFromFile',
    key='img',
    color_type='color',
    channel_order='rgb')

# Bring the loaded image to the fixed spatial size (width, height).
_fixed_size_step = dict(
    type='Resize',
    scale=(32, 32),
    keys=['img'],
    interpolation='bicubic',
    backend='pillow')

inference_pipeline = [
    _load_step,
    _fixed_size_step,
    dict(type='PackEditInputs'),
]

# dataset settings
dataset_type = 'BasicImageDataset'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
from mmengine import mkdir_or_exist
from mmengine.dataset import Compose
from mmengine.dataset.utils import default_collate as collate
from torch.nn.parallel import scatter

from mmedit.utils import tensor2img
from .base_mmedit_inferencer import BaseMMEditInferencer, InputsType, PredType


class RestorationInferencer(BaseMMEditInferencer):
class ImageSuperResolutionInferencer(BaseMMEditInferencer):
"""inferencer that predicts with restoration models."""

func_kwargs = dict(
preprocess=['img'],
preprocess=['img', 'ref'],
forward=[],
visualize=['result_out_dir'],
postprocess=[])
Expand All @@ -38,14 +37,15 @@ def preprocess(self, img: InputsType, ref: InputsType = None) -> Dict:
device = next(self.model.parameters()).device # model device

# select the data pipeline
if cfg.get('demo_pipeline', None):
if cfg.get('inference_pipeline', None):
test_pipeline = cfg.inference_pipeline
elif cfg.get('demo_pipeline', None):
test_pipeline = cfg.demo_pipeline
elif cfg.get('test_pipeline', None):
test_pipeline = cfg.test_pipeline
else:
test_pipeline = cfg.val_pipeline

# remove gt from test_pipeline
keys_to_remove = ['gt', 'gt_path']
for key in keys_to_remove:
for pipeline in list(test_pipeline):
Expand All @@ -57,31 +57,31 @@ def preprocess(self, img: InputsType, ref: InputsType = None) -> Dict:
test_pipeline.remove(pipeline)
if 'meta_keys' in pipeline and key in pipeline['meta_keys']:
pipeline['meta_keys'].remove(key)

# build the data pipeline
test_pipeline = Compose(test_pipeline)

# prepare data
if ref: # Ref-SR
data = dict(img_path=img, ref_path=ref)
data = dict(img_path=img, gt_path=ref)
else: # SISR
data = dict(img_path=img)
_data = test_pipeline(data)

data = dict()
data_preprocessor = cfg['model']['data_preprocessor']
mean = torch.Tensor(data_preprocessor['mean']).view([3, 1, 1])
std = torch.Tensor(data_preprocessor['std']).view([3, 1, 1])
data['inputs'] = (_data['inputs'] - mean) / std
data = collate([data])

if ref:
data['data_samples'] = [_data['data_samples']]
if 'cuda' in str(device):
data = scatter(data, [device])[0]
data['inputs'] = data['inputs'].cuda()
if ref:
data['data_samples'][0].img_lq.data = data['data_samples'][
0].img_lq.data.to(device)
data['data_samples'][0].ref_lq.data = data['data_samples'][
0].ref_lq.data.to(device)
data['data_samples'][0].ref_img.data = data['data_samples'][
0].ref_img.data.to(device)
data['data_samples'][0] = data['data_samples'][0].cuda()

return data

def forward(self, inputs: InputsType) -> PredType:
Expand Down
6 changes: 3 additions & 3 deletions mmedit/apis/inferencers/mmedit_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from .colorization_inferencer import ColorizationInferencer
from .conditional_inferencer import ConditionalInferencer
from .eg3d_inferencer import EG3DInferencer
from .image_super_resolution_inferencer import ImageSuperResolutionInferencer
from .inpainting_inferencer import InpaintingInferencer
from .matting_inferencer import MattingInferencer
from .restoration_inferencer import RestorationInferencer
from .text2image_inferencer import Text2ImageInferencer
from .translation_inferencer import TranslationInferencer
from .unconditional_inferencer import UnconditionalInferencer
Expand Down Expand Up @@ -55,8 +55,8 @@ def __init__(self,
elif self.task in ['translation', 'Image2Image']:
self.inferencer = TranslationInferencer(
config, ckpt, device, extra_parameters, seed=seed)
elif self.task in ['restoration', 'Image Super-Resolution']:
self.inferencer = RestorationInferencer(
elif self.task in ['Image super-resolution', 'Image Super-Resolution']:
self.inferencer = ImageSuperResolutionInferencer(
config, ckpt, device, extra_parameters, seed=seed)
elif self.task in ['video_restoration', 'Video Super-Resolution']:
self.inferencer = VideoRestorationInferencer(
Expand Down
14 changes: 12 additions & 2 deletions mmedit/edit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class MMEdit:
>>> # see demo/mmediting_inference_tutorial.ipynb for more examples
"""
# unsupported now
# singan
# singan, liif
# output should be checked
# dic, glean

inference_supported_models = [
# colorization models
Expand Down Expand Up @@ -71,8 +73,16 @@ class MMEdit:
'pix2pix',
'cyclegan',

# restoration models
# image super-resolution models
'srcnn',
'srgan_resnet',
'edsr',
'esrgan',
'rdn',
'dic',
'ttsr',
'glean',
'real_esrgan',

# video_interpolation models
'flavr',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest
import torch

from mmedit.apis.inferencers.restoration_inferencer import \
RestorationInferencer
from mmedit.apis.inferencers.image_super_resolution_inferencer import \
ImageSuperResolutionInferencer
from mmedit.utils import register_all_modules

register_all_modules()
Expand All @@ -15,16 +15,16 @@
@pytest.mark.skipif(
'win' in platform.system().lower() and 'cu' in torch.__version__,
reason='skip on windows-cuda due to limited RAM.')
def test_restoration_inferencer():
def test_image_super_resolution_inferencer():
data_root = osp.join(osp.dirname(__file__), '../../../')
config = data_root + 'configs/esrgan/esrgan_x4c64b23g32_1xb16-400k_div2k.py' # noqa
img_path = data_root + 'tests/data/image/lq/baboon_x4.png'
result_out_dir = osp.join(
osp.dirname(__file__), '..', '..', 'data/out',
'restoration_result.png')
'image_super_resolution_result.png')

inferencer_instance = \
RestorationInferencer(config, None)
ImageSuperResolutionInferencer(config, None)
inferencer_instance(img=img_path)
inference_result = inferencer_instance(
img=img_path, result_out_dir=result_out_dir)
Expand Down