From 99d58ac8ae697a1c91323c72003bed2d7f6918ae Mon Sep 17 00:00:00 2001 From: ckkelvinchan Date: Thu, 23 Sep 2021 14:28:01 +0800 Subject: [PATCH 1/3] add RealESRGAN model --- mmedit/models/restorers/__init__.py | 3 +- mmedit/models/restorers/real_esrgan.py | 235 ++++++++++++++++ .../test_restorers/test_real_esrgan.py | 259 ++++++++++++++++++ 3 files changed, 496 insertions(+), 1 deletion(-) create mode 100644 mmedit/models/restorers/real_esrgan.py create mode 100644 tests/test_models/test_restorers/test_real_esrgan.py diff --git a/mmedit/models/restorers/__init__.py b/mmedit/models/restorers/__init__.py index 0f69055a65..f334495608 100644 --- a/mmedit/models/restorers/__init__.py +++ b/mmedit/models/restorers/__init__.py @@ -6,11 +6,12 @@ from .esrgan import ESRGAN from .glean import GLEAN from .liif import LIIF +from .real_esrgan import RealESRGAN from .srgan import SRGAN from .tdan import TDAN from .ttsr import TTSR __all__ = [ 'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR', - 'GLEAN', 'TDAN', 'DIC' + 'GLEAN', 'TDAN', 'DIC', 'RealESRGAN' ] diff --git a/mmedit/models/restorers/real_esrgan.py b/mmedit/models/restorers/real_esrgan.py new file mode 100644 index 0000000000..5ca71f9ec7 --- /dev/null +++ b/mmedit/models/restorers/real_esrgan.py @@ -0,0 +1,235 @@ +import numbers +import os.path as osp +from copy import deepcopy + +import mmcv +import torch +from mmcv.parallel import is_module_wrapper + +from mmedit.core import tensor2img +from ..common import set_requires_grad +from ..registry import MODELS +from .srgan import SRGAN + + +@MODELS.register_module() +class RealESRGAN(SRGAN): + """Real-ESRGAN model for single image super-resolution. + + Ref: + Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure + Synthetic Data, 2021. + + Args: + generator (dict): Config for the generator. + discriminator (dict, optional): Config for the discriminator. + Default: None. + gan_loss (dict, optional): Config for the gan loss. 
            Note that the loss weight in gan loss is only for the generator.
        pixel_loss (dict, optional): Config for the pixel loss. Default: None.
        perceptual_loss (dict, optional): Config for the perceptual loss.
            Default: None.
        is_use_sharpened_gt_in_pixel (bool, optional): Whether to use the image
            sharpened by unsharp masking as the GT for pixel loss.
            Default: False.
        is_use_sharpened_gt_in_percep (bool, optional): Whether to use the
            image sharpened by unsharp masking as the GT for perceptual loss.
            Default: False.
        is_use_sharpened_gt_in_gan (bool, optional): Whether to use the
            image sharpened by unsharp masking as the GT for adversarial loss.
            Default: False.
        is_use_ema (bool, optional): Whether to apply exponential moving
            average on the network weights. Default: True.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
                update;
            `disc_init_steps`: how many discriminator updates at the start of
                the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
+ """ + + def __init__(self, + generator, + discriminator=None, + gan_loss=None, + pixel_loss=None, + perceptual_loss=None, + is_use_sharpened_gt_in_pixel=False, + is_use_sharpened_gt_in_percep=False, + is_use_sharpened_gt_in_gan=False, + is_use_ema=True, + train_cfg=None, + test_cfg=None, + pretrained=None): + + super().__init__(generator, discriminator, gan_loss, pixel_loss, + perceptual_loss, train_cfg, test_cfg, pretrained) + + self.is_use_sharpened_gt_in_pixel = is_use_sharpened_gt_in_pixel + self.is_use_sharpened_gt_in_percep = is_use_sharpened_gt_in_percep + self.is_use_sharpened_gt_in_gan = is_use_sharpened_gt_in_gan + + self.is_use_ema = is_use_ema + if is_use_ema: + self.generator_ema = deepcopy(self.generator) + else: + self.generator_ema = None + + del self.step_counter + self.register_buffer('step_counter', torch.zeros(1)) + + if train_cfg is not None: # used for initializeing from ema model + self.start_iter = train_cfg.get('start_iter', -1) + else: + self.start_iter = -1 + + def train_step(self, data_batch, optimizer): + """Train step. + + Args: + data_batch (dict): A batch of data. + optimizer (obj): Optimizer. + + Returns: + dict: Returned output. 
+ """ + # during initialization, load weights from the ema model + if (self.step_counter == self.start_iter + and self.generator_ema is not None): + if is_module_wrapper(self.generator): + self.generator.module.load_state_dict( + self.generator_ema.module.state_dict()) + else: + self.generator.load_state_dict(self.generator_ema.state_dict()) + + # self.generator = deepcopy(self.generator_ema) + + # data + lq = data_batch['lq'] + gt = data_batch['gt'] + + gt_pixel, gt_percep, gt_gan = gt.clone(), gt.clone(), gt.clone() + if self.is_use_sharpened_gt_in_pixel: + gt_pixel = data_batch['gt_unsharp'] + if self.is_use_sharpened_gt_in_percep: + gt_percep = data_batch['gt_unsharp'] + if self.is_use_sharpened_gt_in_gan: + gt_gan = data_batch['gt_unsharp'] + + # generator + fake_g_output = self.generator(lq) + + losses = dict() + log_vars = dict() + + # no updates to discriminator parameters. + if self.gan_loss: + set_requires_grad(self.discriminator, False) + + if (self.step_counter % self.disc_steps == 0 + and self.step_counter >= self.disc_init_steps): + if self.pixel_loss: + losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel) + if self.perceptual_loss: + loss_percep, loss_style = self.perceptual_loss( + fake_g_output, gt_percep) + if loss_percep is not None: + losses['loss_perceptual'] = loss_percep + if loss_style is not None: + losses['loss_style'] = loss_style + + # gan loss for generator + if self.gan_loss: + fake_g_pred = self.discriminator(fake_g_output) + losses['loss_gan'] = self.gan_loss( + fake_g_pred, target_is_real=True, is_disc=False) + + # parse loss + loss_g, log_vars_g = self.parse_losses(losses) + log_vars.update(log_vars_g) + + # optimize + optimizer['generator'].zero_grad() + loss_g.backward() + optimizer['generator'].step() + + # discriminator + if self.gan_loss: + set_requires_grad(self.discriminator, True) + # real + real_d_pred = self.discriminator(gt_gan) + loss_d_real = self.gan_loss( + real_d_pred, target_is_real=True, is_disc=True) + 
loss_d, log_vars_d = self.parse_losses( + dict(loss_d_real=loss_d_real)) + optimizer['discriminator'].zero_grad() + loss_d.backward() + log_vars.update(log_vars_d) + # fake + fake_d_pred = self.discriminator(fake_g_output.detach()) + loss_d_fake = self.gan_loss( + fake_d_pred, target_is_real=False, is_disc=True) + loss_d, log_vars_d = self.parse_losses( + dict(loss_d_fake=loss_d_fake)) + loss_d.backward() + log_vars.update(log_vars_d) + + optimizer['discriminator'].step() + + self.step_counter += 1 + + log_vars.pop('loss') # remove the unnecessary 'loss' + outputs = dict( + log_vars=log_vars, + num_samples=len(gt.data), + results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu())) + + return outputs + + def forward_test(self, + lq, + gt=None, + meta=None, + save_image=False, + save_path=None, + iteration=None): + """Testing forward function. + + Args: + lq (Tensor): LQ Tensor with shape (n, c, h, w). + gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None. + save_image (bool): Whether to save image. Default: False. + save_path (str): Path to save image. Default: None. + iteration (int): Iteration for the saving image name. + Default: None. + + Returns: + dict: Output results. 
+ """ + _model = self.generator_ema if self.is_use_ema else self.generator + output = _model(lq) + + if self.test_cfg is not None and self.test_cfg.get( + 'metrics', None) and gt is not None: + results = dict(eval_result=self.evaluate(output, gt)) + else: + results = dict(lq=lq.cpu(), output=output.cpu()) + + # save image + if save_image: + lq_path = meta[0]['lq_path'] + folder_name = osp.splitext(osp.basename(lq_path))[0] + if isinstance(iteration, numbers.Number): + save_path = osp.join(save_path, folder_name, + f'{folder_name}-{iteration + 1:06d}.png') + elif iteration is None: + save_path = osp.join(save_path, f'{folder_name}.png') + else: + raise ValueError('iteration should be number or None, ' + f'but got {type(iteration)}') + mmcv.imwrite(tensor2img(output), save_path) + + return results diff --git a/tests/test_models/test_restorers/test_real_esrgan.py b/tests/test_models/test_restorers/test_real_esrgan.py new file mode 100644 index 0000000000..154f62a34f --- /dev/null +++ b/tests/test_models/test_restorers/test_real_esrgan.py @@ -0,0 +1,259 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest.mock import patch + +import pytest +import torch +from mmcv.runner import obj_from_dict + +from mmedit.models import build_model +from mmedit.models.backbones import MSRResNet +from mmedit.models.components import ModifiedVGG +from mmedit.models.losses import GANLoss, L1Loss + + +def test_real_esrgan(): + + model_cfg = dict( + type='RealESRGAN', + generator=dict( + type='MSRResNet', + in_channels=3, + out_channels=3, + mid_channels=4, + num_blocks=1, + upscale_factor=4), + discriminator=dict(type='ModifiedVGG', in_channels=3, mid_channels=2), + pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'), + gan_loss=dict( + type='GANLoss', + gan_type='vanilla', + loss_weight=1e-1, + real_label_val=1.0, + fake_label_val=0), + is_use_sharpened_gt_in_pixel=True, + is_use_sharpened_gt_in_percep=True, + is_use_sharpened_gt_in_gan=True, + is_use_ema=True, + ) + + train_cfg = None + test_cfg = None + + # build restorer + restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg) + + # test attributes + assert restorer.__class__.__name__ == 'RealESRGAN' + assert isinstance(restorer.generator, MSRResNet) + assert isinstance(restorer.discriminator, ModifiedVGG) + assert isinstance(restorer.pixel_loss, L1Loss) + assert isinstance(restorer.gan_loss, GANLoss) + + # prepare data + inputs = torch.rand(1, 3, 32, 32) + targets = torch.rand(1, 3, 128, 128) + data_batch = {'lq': inputs, 'gt': targets, 'gt_unsharp': targets} + + # prepare optimizer + optim_cfg = dict(type='Adam', lr=2e-4, betas=(0.9, 0.999)) + optimizer = { + 'generator': + obj_from_dict(optim_cfg, torch.optim, + dict( + params=getattr(restorer, 'generator').parameters())), + 'discriminator': + obj_from_dict( + optim_cfg, torch.optim, + dict(params=getattr(restorer, 'discriminator').parameters())) + } + + # no forward train in GAN models, raise ValueError + with pytest.raises(ValueError): + restorer(**data_batch, test_mode=False) + + # test forward_test + data_batch.pop('gt_unsharp') 
+ with torch.no_grad(): + outputs = restorer(**data_batch, test_mode=True) + assert torch.equal(outputs['lq'], data_batch['lq']) + assert torch.is_tensor(outputs['output']) + assert outputs['output'].size() == (1, 3, 128, 128) + + # test forward_dummy + with torch.no_grad(): + output = restorer.forward_dummy(data_batch['lq']) + assert torch.is_tensor(output) + assert output.size() == (1, 3, 128, 128) + + # val_step + with torch.no_grad(): + outputs = restorer.val_step(data_batch) + data_batch['gt_unsharp'] = targets + assert torch.equal(outputs['lq'], data_batch['lq']) + assert torch.is_tensor(outputs['output']) + assert outputs['output'].size() == (1, 3, 128, 128) + + # test train_step + with patch.object( + restorer, + 'perceptual_loss', + return_value=(torch.tensor(1.0), torch.tensor(2.0))): + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + for v in [ + 'loss_perceptual', 'loss_gan', 'loss_d_real', 'loss_d_fake', + 'loss_pix' + ]: + assert isinstance(outputs['log_vars'][v], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], data_batch['lq']) + assert torch.equal(outputs['results']['gt'], data_batch['gt']) + assert torch.is_tensor(outputs['results']['output']) + assert outputs['results']['output'].size() == (1, 3, 128, 128) + + # test train_step and forward_test (gpu) + if torch.cuda.is_available(): + restorer = restorer.cuda() + optimizer = { + 'generator': + obj_from_dict( + optim_cfg, torch.optim, + dict(params=getattr(restorer, 'generator').parameters())), + 'discriminator': + obj_from_dict( + optim_cfg, torch.optim, + dict(params=getattr(restorer, 'discriminator').parameters())) + } + data_batch = { + 'lq': inputs.cuda(), + 'gt': targets.cuda(), + 'gt_unsharp': targets.cuda() + } + + # forward_test + data_batch.pop('gt_unsharp') + with torch.no_grad(): + outputs = restorer(**data_batch, test_mode=True) + assert 
torch.equal(outputs['lq'], data_batch['lq'].cpu()) + assert torch.is_tensor(outputs['output']) + assert outputs['output'].size() == (1, 3, 128, 128) + + # val_step + with torch.no_grad(): + outputs = restorer.val_step(data_batch) + data_batch['gt_unsharp'] = targets.cuda() + assert torch.equal(outputs['lq'], data_batch['lq'].cpu()) + assert torch.is_tensor(outputs['output']) + assert outputs['output'].size() == (1, 3, 128, 128) + + # train_step + with patch.object( + restorer, + 'perceptual_loss', + return_value=(torch.tensor(1.0).cuda(), + torch.tensor(2.0).cuda())): + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + for v in [ + 'loss_perceptual', 'loss_gan', 'loss_d_real', + 'loss_d_fake', 'loss_pix' + ]: + assert isinstance(outputs['log_vars'][v], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], + data_batch['lq'].cpu()) + assert torch.equal(outputs['results']['gt'], + data_batch['gt'].cpu()) + assert torch.is_tensor(outputs['results']['output']) + assert outputs['results']['output'].size() == (1, 3, 128, 128) + + # test disc_steps and disc_init_steps + data_batch = { + 'lq': inputs.cpu(), + 'gt': targets.cpu(), + 'gt_unsharp': targets.cpu() + } + train_cfg = dict(disc_steps=2, disc_init_steps=2) + restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg) + with patch.object( + restorer, + 'perceptual_loss', + return_value=(torch.tensor(1.0), torch.tensor(2.0))): + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + for v in ['loss_d_real', 'loss_d_fake']: + assert isinstance(outputs['log_vars'][v], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], data_batch['lq']) + assert torch.equal(outputs['results']['gt'], data_batch['gt']) + assert torch.is_tensor(outputs['results']['output']) + assert 
outputs['results']['output'].size() == (1, 3, 128, 128) + + # test no discriminator (testing mode) + model_cfg_ = model_cfg.copy() + model_cfg_.pop('discriminator') + restorer = build_model(model_cfg_, train_cfg=train_cfg, test_cfg=test_cfg) + data_batch.pop('gt_unsharp') + with torch.no_grad(): + outputs = restorer(**data_batch, test_mode=True) + data_batch['gt_unsharp'] = targets.cpu() + assert torch.equal(outputs['lq'], data_batch['lq']) + assert torch.is_tensor(outputs['output']) + assert outputs['output'].size() == (1, 3, 128, 128) + + # test without pixel loss and perceptual loss + model_cfg_ = model_cfg.copy() + model_cfg_.pop('pixel_loss') + restorer = build_model(model_cfg_, train_cfg=None, test_cfg=None) + + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + for v in ['loss_gan', 'loss_d_real', 'loss_d_fake']: + assert isinstance(outputs['log_vars'][v], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], data_batch['lq']) + assert torch.equal(outputs['results']['gt'], data_batch['gt']) + assert torch.is_tensor(outputs['results']['output']) + assert outputs['results']['output'].size() == (1, 3, 128, 128) + + # test train_step w/o loss_percep + restorer = build_model(model_cfg, train_cfg=None, test_cfg=None) + with patch.object( + restorer, 'perceptual_loss', + return_value=(None, torch.tensor(2.0))): + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + for v in [ + 'loss_style', 'loss_gan', 'loss_d_real', 'loss_d_fake', + 'loss_pix' + ]: + assert isinstance(outputs['log_vars'][v], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], data_batch['lq']) + assert torch.equal(outputs['results']['gt'], data_batch['gt']) + assert torch.is_tensor(outputs['results']['output']) + assert outputs['results']['output'].size() 
== (1, 3, 128, 128) + + # test train_step w/o loss_style + restorer = build_model(model_cfg, train_cfg=None, test_cfg=None) + with patch.object( + restorer, 'perceptual_loss', + return_value=(torch.tensor(2.0), None)): + outputs = restorer.train_step(data_batch, optimizer) + assert isinstance(outputs, dict) + assert isinstance(outputs['log_vars'], dict) + for v in [ + 'loss_perceptual', 'loss_gan', 'loss_d_real', 'loss_d_fake', + 'loss_pix' + ]: + assert isinstance(outputs['log_vars'][v], float) + assert outputs['num_samples'] == 1 + assert torch.equal(outputs['results']['lq'], data_batch['lq']) + assert torch.equal(outputs['results']['gt'], data_batch['gt']) + assert torch.is_tensor(outputs['results']['output']) + assert outputs['results']['output'].size() == (1, 3, 128, 128) From 64a93b3eb2e2886e734b4cf435ec2ba661725815 Mon Sep 17 00:00:00 2001 From: ckkelvinchan Date: Thu, 23 Sep 2021 14:29:48 +0800 Subject: [PATCH 2/3] remove deepcopy --- mmedit/models/restorers/real_esrgan.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mmedit/models/restorers/real_esrgan.py b/mmedit/models/restorers/real_esrgan.py index 5ca71f9ec7..6299615139 100644 --- a/mmedit/models/restorers/real_esrgan.py +++ b/mmedit/models/restorers/real_esrgan.py @@ -105,8 +105,6 @@ def train_step(self, data_batch, optimizer): else: self.generator.load_state_dict(self.generator_ema.state_dict()) - # self.generator = deepcopy(self.generator_ema) - # data lq = data_batch['lq'] gt = data_batch['gt'] From 6414162c6969c4a0dc66219a82a1d1863152ba8f Mon Sep 17 00:00:00 2001 From: ckkelvinchan Date: Thu, 23 Sep 2021 15:06:57 +0800 Subject: [PATCH 3/3] add test for start_iter --- tests/test_models/test_restorers/test_real_esrgan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_models/test_restorers/test_real_esrgan.py b/tests/test_models/test_restorers/test_real_esrgan.py index 154f62a34f..ebfcd2aa70 100644 --- a/tests/test_models/test_restorers/test_real_esrgan.py +++ 
b/tests/test_models/test_restorers/test_real_esrgan.py @@ -169,13 +169,13 @@ def test_real_esrgan(): assert torch.is_tensor(outputs['results']['output']) assert outputs['results']['output'].size() == (1, 3, 128, 128) - # test disc_steps and disc_init_steps + # test disc_steps and disc_init_steps and start_iter data_batch = { 'lq': inputs.cpu(), 'gt': targets.cpu(), 'gt_unsharp': targets.cpu() } - train_cfg = dict(disc_steps=2, disc_init_steps=2) + train_cfg = dict(disc_steps=2, disc_init_steps=2, start_iter=0) restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg) with patch.object( restorer,