Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Real-ESRGAN model #546

Merged
merged 3 commits into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmedit/models/restorers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
from .esrgan import ESRGAN
from .glean import GLEAN
from .liif import LIIF
from .real_esrgan import RealESRGAN
from .srgan import SRGAN
from .tdan import TDAN
from .ttsr import TTSR

__all__ = [
'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR',
'GLEAN', 'TDAN', 'DIC'
'GLEAN', 'TDAN', 'DIC', 'RealESRGAN'
]
233 changes: 233 additions & 0 deletions mmedit/models/restorers/real_esrgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import numbers
import os.path as osp
from copy import deepcopy

import mmcv
import torch
from mmcv.parallel import is_module_wrapper

from mmedit.core import tensor2img
from ..common import set_requires_grad
from ..registry import MODELS
from .srgan import SRGAN


@MODELS.register_module()
class RealESRGAN(SRGAN):
    """Real-ESRGAN model for single image super-resolution.

    Ref:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure
    Synthetic Data, 2021.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict, optional): Config for the discriminator.
            Default: None.
        gan_loss (dict, optional): Config for the gan loss.
            Note that the loss weight in gan loss is only for the generator.
        pixel_loss (dict, optional): Config for the pixel loss. Default: None.
        perceptual_loss (dict, optional): Config for the perceptual loss.
            Default: None.
        is_use_sharpened_gt_in_pixel (bool, optional): Whether to use the image
            sharpened by unsharp masking as the GT for pixel loss.
            Default: False.
        is_use_sharpened_gt_in_percep (bool, optional): Whether to use the
            image sharpened by unsharp masking as the GT for perceptual loss.
            Default: False.
        is_use_sharpened_gt_in_gan (bool, optional): Whether to use the
            image sharpened by unsharp masking as the GT for adversarial loss.
            Default: False.
        is_use_ema (bool, optional): Whether to apply exponential moving
            average on the network weights. Default: True.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 perceptual_loss=None,
                 is_use_sharpened_gt_in_pixel=False,
                 is_use_sharpened_gt_in_percep=False,
                 is_use_sharpened_gt_in_gan=False,
                 is_use_ema=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):

        super().__init__(generator, discriminator, gan_loss, pixel_loss,
                         perceptual_loss, train_cfg, test_cfg, pretrained)

        self.is_use_sharpened_gt_in_pixel = is_use_sharpened_gt_in_pixel
        self.is_use_sharpened_gt_in_percep = is_use_sharpened_gt_in_percep
        self.is_use_sharpened_gt_in_gan = is_use_sharpened_gt_in_gan

        self.is_use_ema = is_use_ema
        if is_use_ema:
            # A separate copy of the generator used for EMA-based testing.
            self.generator_ema = deepcopy(self.generator)
        else:
            self.generator_ema = None

        # Re-register the step counter as a buffer so it is stored in
        # checkpoints and moved across devices together with the model.
        del self.step_counter
        self.register_buffer('step_counter', torch.zeros(1))

        if train_cfg is not None:  # used for initializing from ema model
            self.start_iter = train_cfg.get('start_iter', -1)
        else:
            self.start_iter = -1

    def train_step(self, data_batch, optimizer):
        """Train step.

        Args:
            data_batch (dict): A batch of data. Must contain keys ``lq`` and
                ``gt``; ``gt_unsharp`` is additionally required when any of
                the ``is_use_sharpened_gt_in_*`` flags is set.
            optimizer (dict[str, obj]): Optimizers for ``'generator'`` and
                (when a gan loss is configured) ``'discriminator'``.

        Returns:
            dict: Returned output with keys ``log_vars``, ``num_samples``
                and ``results``.
        """
        # during initialization, load weights from the ema model
        if (self.step_counter == self.start_iter
                and self.generator_ema is not None):
            # `generator_ema` is a plain deepcopy made in `__init__`, so it
            # is not necessarily wrapped the same way as `self.generator`
            # (e.g. by a data-parallel wrapper).  Unwrap each side
            # independently instead of assuming matching wrappers.
            src = self.generator_ema
            if is_module_wrapper(src):
                src = src.module
            dst = self.generator
            if is_module_wrapper(dst):
                dst = dst.module
            dst.load_state_dict(src.state_dict())

        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # Each loss may use either the plain GT or the unsharp-masked GT.
        gt_pixel, gt_percep, gt_gan = gt.clone(), gt.clone(), gt.clone()
        if self.is_use_sharpened_gt_in_pixel:
            gt_pixel = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_percep:
            gt_percep = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_gan:
            gt_gan = data_batch['gt_unsharp']

        # generator
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters.
        if self.gan_loss:
            set_requires_grad(self.discriminator, False)

        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel)
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt_percep)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # gan loss for generator
            if self.gan_loss:
                fake_g_pred = self.discriminator(fake_g_output)
                losses['loss_gan'] = self.gan_loss(
                    fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        if self.gan_loss:
            set_requires_grad(self.discriminator, True)
            # real: backward real and fake separately to save memory
            real_d_pred = self.discriminator(gt_gan)
            loss_d_real = self.gan_loss(
                real_d_pred, target_is_real=True, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_real=loss_d_real))
            optimizer['discriminator'].zero_grad()
            loss_d.backward()
            log_vars.update(log_vars_d)
            # fake: detach so no gradients flow back into the generator
            fake_d_pred = self.discriminator(fake_g_output.detach())
            loss_d_fake = self.gan_loss(
                fake_d_pred, target_is_real=False, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_fake=loss_d_fake))
            loss_d.backward()
            log_vars.update(log_vars_d)

            optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs

    def forward_test(self,
                     lq,
                     gt=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None.
            meta (list[dict]): Meta data, whose first entry must contain
                ``'lq_path'`` when ``save_image`` is True. Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: Output results. Contains ``eval_result`` when metrics are
                configured and ``gt`` is given; otherwise ``lq`` and
                ``output``.

        Raises:
            ValueError: If ``save_image`` is True and ``iteration`` is
                neither a number nor None.
        """
        # Test with the EMA weights when enabled; they are typically more
        # stable than the raw generator weights.
        _model = self.generator_ema if self.is_use_ema else self.generator
        output = _model(lq)

        if self.test_cfg is not None and self.test_cfg.get(
                'metrics', None) and gt is not None:
            results = dict(eval_result=self.evaluate(output, gt))
        else:
            results = dict(lq=lq.cpu(), output=output.cpu())

        # save image
        if save_image:
            lq_path = meta[0]['lq_path']
            folder_name = osp.splitext(osp.basename(lq_path))[0]
            if isinstance(iteration, numbers.Number):
                save_path = osp.join(save_path, folder_name,
                                     f'{folder_name}-{iteration + 1:06d}.png')
            elif iteration is None:
                save_path = osp.join(save_path, f'{folder_name}.png')
            else:
                raise ValueError('iteration should be number or None, '
                                 f'but got {type(iteration)}')
            mmcv.imwrite(tensor2img(output), save_path)

        return results
Loading