forked from open-mmlab/mmagic
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Real-ESRGAN model (open-mmlab#546)
* add RealESRGAN model * remove deepcopy * add test for start_iter
- Loading branch information
1 parent
a40653c
commit 522e77b
Showing
3 changed files
with
494 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
import numbers | ||
import os.path as osp | ||
from copy import deepcopy | ||
|
||
import mmcv | ||
import torch | ||
from mmcv.parallel import is_module_wrapper | ||
|
||
from mmedit.core import tensor2img | ||
from ..common import set_requires_grad | ||
from ..registry import MODELS | ||
from .srgan import SRGAN | ||
|
||
|
||
@MODELS.register_module()
class RealESRGAN(SRGAN):
    """Real-ESRGAN model for single image super-resolution.

    Extends :class:`SRGAN` with three Real-ESRGAN-specific behaviours:

    1. Each of the pixel / perceptual / adversarial losses may optionally be
       computed against an unsharp-masked (sharpened) GT image, taken from
       ``data_batch['gt_unsharp']``, instead of the plain GT.
    2. An exponential-moving-average (EMA) copy of the generator is kept and
       used at test time (see ``forward_test``).
    3. At iteration ``start_iter`` the generator weights are (re)loaded from
       the EMA copy (see ``train_step``).

    Ref:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure
    Synthetic Data, 2021.

    Args:
        generator (dict): Config for the generator.
        discriminator (dict, optional): Config for the discriminator.
            Default: None.
        gan_loss (dict, optional): Config for the gan loss.
            Note that the loss weight in gan loss is only for the generator.
        pixel_loss (dict, optional): Config for the pixel loss. Default: None.
        perceptual_loss (dict, optional): Config for the perceptual loss.
            Default: None.
        is_use_sharpened_gt_in_pixel (bool, optional): Whether to use the image
            sharpened by unsharp masking as the GT for pixel loss.
            Default: False.
        is_use_sharpened_gt_in_percep (bool, optional): Whether to use the
            image sharpened by unsharp masking as the GT for perceptual loss.
            Default: False.
        is_use_sharpened_gt_in_gan (bool, optional): Whether to use the
            image sharpened by unsharp masking as the GT for adversarial loss.
            Default: False.
        is_use_ema (bool, optional): Whether to keep an exponential moving
            average of the generator weights (used for testing and for
            re-initialization at ``start_iter``). Default: True.
        train_cfg (dict): Config for training. Default: None.
            You may change the training of gan by setting:
            `disc_steps`: how many discriminator updates after one generator
            update;
            `disc_init_steps`: how many discriminator updates at the start of
            the training.
            These two keys are useful when training with WGAN.
            `start_iter`: iteration at which the generator weights are loaded
            from the EMA model. Default: -1 (never).
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 perceptual_loss=None,
                 is_use_sharpened_gt_in_pixel=False,
                 is_use_sharpened_gt_in_percep=False,
                 is_use_sharpened_gt_in_gan=False,
                 is_use_ema=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):

        super().__init__(generator, discriminator, gan_loss, pixel_loss,
                         perceptual_loss, train_cfg, test_cfg, pretrained)

        # Per-loss switches: when True, that loss uses the unsharp-masked GT
        # (``data_batch['gt_unsharp']``) as its target in ``train_step``.
        self.is_use_sharpened_gt_in_pixel = is_use_sharpened_gt_in_pixel
        self.is_use_sharpened_gt_in_percep = is_use_sharpened_gt_in_percep
        self.is_use_sharpened_gt_in_gan = is_use_sharpened_gt_in_gan

        # EMA copy of the generator. Used by ``forward_test`` when
        # ``is_use_ema`` is True, and as the weight source at ``start_iter``
        # in ``train_step``. NOTE(review): the copy is taken here, before any
        # distributed/data-parallel wrapping of ``self.generator`` could
        # happen — see the related note in ``train_step``.
        self.is_use_ema = is_use_ema
        if is_use_ema:
            self.generator_ema = deepcopy(self.generator)
        else:
            self.generator_ema = None

        # Replace the step counter inherited from the parent class with a
        # registered buffer so the current iteration is stored in (and
        # restored from) checkpoints along with the model state.
        del self.step_counter
        self.register_buffer('step_counter', torch.zeros(1))

        if train_cfg is not None:  # used for initializing from ema model
            self.start_iter = train_cfg.get('start_iter', -1)
        else:
            self.start_iter = -1

    def train_step(self, data_batch, optimizer):
        """Train step.

        Performs one generator update (subject to ``disc_steps`` /
        ``disc_init_steps`` gating) followed by one discriminator update
        (real and fake passes, each with its own backward), then advances
        ``self.step_counter``.

        Args:
            data_batch (dict): A batch of data. Must contain ``'lq'`` and
                ``'gt'``; must also contain ``'gt_unsharp'`` if any of the
                ``is_use_sharpened_gt_in_*`` flags is set.
            optimizer (obj): Dict-like container of optimizers, indexed by
                ``'generator'`` and ``'discriminator'``.

        Returns:
            dict: Returned output with keys ``log_vars``, ``num_samples``
                and ``results`` (lq/gt/output tensors moved to CPU).
        """
        # during initialization, load weights from the ema model
        if (self.step_counter == self.start_iter
                and self.generator_ema is not None):
            if is_module_wrapper(self.generator):
                # NOTE(review): this assumes ``self.generator_ema`` is also a
                # module wrapper whenever ``self.generator`` is. The EMA copy
                # was deepcopied in ``__init__``; confirm it gets wrapped the
                # same way before relying on ``.module`` here.
                self.generator.module.load_state_dict(
                    self.generator_ema.module.state_dict())
            else:
                self.generator.load_state_dict(self.generator_ema.state_dict())

        # data
        lq = data_batch['lq']
        gt = data_batch['gt']

        # Select the GT target per loss: plain GT by default, the
        # unsharp-masked GT when the corresponding flag is set.
        gt_pixel, gt_percep, gt_gan = gt.clone(), gt.clone(), gt.clone()
        if self.is_use_sharpened_gt_in_pixel:
            gt_pixel = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_percep:
            gt_percep = data_batch['gt_unsharp']
        if self.is_use_sharpened_gt_in_gan:
            gt_gan = data_batch['gt_unsharp']

        # generator forward pass (also reused for the discriminator's fake
        # pass below, detached there)
        fake_g_output = self.generator(lq)

        losses = dict()
        log_vars = dict()

        # no updates to discriminator parameters during the generator step.
        if self.gan_loss:
            set_requires_grad(self.discriminator, False)

        # Update the generator only every ``disc_steps`` iterations and only
        # after ``disc_init_steps`` warm-up iterations of the discriminator.
        # (Both attributes are presumably set by ``SRGAN.__init__`` from
        # ``train_cfg`` — not visible in this file.)
        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel)
            if self.perceptual_loss:
                # perceptual_loss returns a (perceptual, style) pair; either
                # component may be None and is then omitted from the losses.
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt_percep)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # gan loss for generator (discriminator frozen above)
            if self.gan_loss:
                fake_g_pred = self.discriminator(fake_g_output)
                losses['loss_gan'] = self.gan_loss(
                    fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator: two separate backward passes (real then fake),
        # gradients accumulate, then a single optimizer step.
        if self.gan_loss:
            set_requires_grad(self.discriminator, True)
            # real
            real_d_pred = self.discriminator(gt_gan)
            loss_d_real = self.gan_loss(
                real_d_pred, target_is_real=True, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_real=loss_d_real))
            optimizer['discriminator'].zero_grad()
            loss_d.backward()
            log_vars.update(log_vars_d)
            # fake — detach so no gradient flows back into the generator
            fake_d_pred = self.discriminator(fake_g_output.detach())
            loss_d_fake = self.gan_loss(
                fake_d_pred, target_is_real=False, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_fake=loss_d_fake))
            loss_d.backward()
            log_vars.update(log_vars_d)

            optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs

    def forward_test(self,
                     lq,
                     gt=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Runs the EMA generator when ``self.is_use_ema`` is True, otherwise
        the plain generator.

        Args:
            lq (Tensor): LQ Tensor with shape (n, c, h, w).
            gt (Tensor): GT Tensor with shape (n, c, h, w). Default: None.
            meta (list[dict]): Meta information for the batch; required when
                ``save_image`` is True (``meta[0]['lq_path']`` names the
                saved file). Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image; required when ``save_image``
                is True. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: ``dict(eval_result=...)`` when metrics are configured and
                ``gt`` is given; otherwise ``dict(lq=..., output=...)`` on CPU.

        Raises:
            ValueError: If ``iteration`` is neither a number nor None while
                ``save_image`` is True.
        """
        _model = self.generator_ema if self.is_use_ema else self.generator
        output = _model(lq)

        # Evaluate metrics only when both a metric config and GT exist;
        # otherwise return raw tensors for downstream handling.
        if self.test_cfg is not None and self.test_cfg.get(
                'metrics', None) and gt is not None:
            results = dict(eval_result=self.evaluate(output, gt))
        else:
            results = dict(lq=lq.cpu(), output=output.cpu())

        # save image
        if save_image:
            lq_path = meta[0]['lq_path']
            folder_name = osp.splitext(osp.basename(lq_path))[0]
            if isinstance(iteration, numbers.Number):
                # iteration is 0-based internally; saved name is 1-based.
                save_path = osp.join(save_path, folder_name,
                                     f'{folder_name}-{iteration + 1:06d}.png')
            elif iteration is None:
                save_path = osp.join(save_path, f'{folder_name}.png')
            else:
                raise ValueError('iteration should be number or None, '
                                 f'but got {type(iteration)}')
            mmcv.imwrite(tensor2img(output), save_path)

        return results
Oops, something went wrong.