From b066bb6c4eab63b17a4d7816ea4e12a9f7792ab2 Mon Sep 17 00:00:00 2001
From: ys-li <56712176+Yshuo-Li@users.noreply.github.com>
Date: Tue, 25 May 2021 20:04:44 +0800
Subject: [PATCH] [Feature] Add TTSR. (#321)

* [Feature] Add TTSR.

* Add test

* Fix

Co-authored-by: liyinshuo
---
 mmedit/models/restorers/__init__.py |   5 +-
 mmedit/models/restorers/ttsr.py     | 238 ++++++++++++++++++++++++++++
 tests/test_ttsr.py                  |  90 ++++++++++-
 3 files changed, 331 insertions(+), 2 deletions(-)
 create mode 100644 mmedit/models/restorers/ttsr.py

diff --git a/mmedit/models/restorers/__init__.py b/mmedit/models/restorers/__init__.py
index 40eccbc07d..893cbaab09 100644
--- a/mmedit/models/restorers/__init__.py
+++ b/mmedit/models/restorers/__init__.py
@@ -4,5 +4,8 @@
 from .esrgan import ESRGAN
 from .liif import LIIF
 from .srgan import SRGAN
+from .ttsr import TTSR
 
-__all__ = ['BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR']
+__all__ = [
+    'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR'
+]
diff --git a/mmedit/models/restorers/ttsr.py b/mmedit/models/restorers/ttsr.py
new file mode 100644
index 0000000000..fca5b81e3e
--- /dev/null
+++ b/mmedit/models/restorers/ttsr.py
@@ -0,0 +1,238 @@
+import numbers
+import os.path as osp
+
+import mmcv
+import torch
+
+from mmedit.core import tensor2img
+from ..builder import build_backbone, build_component, build_loss
+from ..registry import MODELS
+from .basic_restorer import BasicRestorer
+
+
+@MODELS.register_module()
+class TTSR(BasicRestorer):
+    """TTSR model for Reference-based Image Super-Resolution.
+
+    Paper: Learning Texture Transformer Network for Image Super-Resolution.
+
+    Args:
+        generator (dict): Config for the generator.
+        extractor (dict): Config for the extractor.
+        transformer (dict): Config for the transformer.
+        pixel_loss (dict): Config for the pixel loss.
+        train_cfg (dict): Config for train. Default: None.
+        test_cfg (dict): Config for testing. Default: None.
+        pretrained (str): Path for pretrained model. Default: None.
+    """
+
+    def __init__(self,
+                 generator,
+                 extractor,
+                 transformer,
+                 pixel_loss,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(BasicRestorer, self).__init__()  # skip BasicRestorer.__init__
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        # model
+        self.generator = build_backbone(generator)
+        self.transformer = build_component(transformer)
+        self.extractor = build_component(extractor)
+
+        # loss
+        self.pixel_loss = build_loss(pixel_loss)
+
+        # pretrained
+        self.init_weights(pretrained)
+
+    def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
+        """Forward of networks.
+
+        Args:
+            lq (Tensor): LQ image.
+            lq_up (Tensor): Upsampled LQ image.
+            ref (Tensor): Reference image.
+            ref_downup (Tensor): Image generated by sequentially applying
+                bicubic down-sampling and up-sampling on reference image.
+            only_pred (bool): Only return predicted results or not.
+                Default: True.
+
+        Returns:
+            pred (Tensor): Predicted super-resolution results (n, 3, 4h, 4w).
+            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
+            t_level3 (Tensor): Transformed HR texture T in level3.
+                (n, 4c, h, w)
+            t_level2 (Tensor): Transformed HR texture T in level2.
+                (n, 2c, 2h, 2w)
+            t_level1 (Tensor): Transformed HR texture T in level1.
+                (n, c, 4h, 4w)
+        """
+
+        _, _, lq_up_level3 = self.extractor(lq_up)
+        _, _, ref_downup_level3 = self.extractor(ref_downup)
+        ref_level1, ref_level2, ref_level3 = self.extractor(ref)
+
+        s, t_level3, t_level2, t_level1 = self.transformer(
+            lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
+            ref_level3)
+
+        pred = self.generator(lq, s, t_level3, t_level2, t_level1)
+
+        if only_pred:
+            return pred
+        return pred, s, t_level3, t_level2, t_level1
+
+    def forward(self, lq, gt=None, test_mode=False, **kwargs):
+        """Forward function.
+
+        Args:
+            lq (Tensor): Input lq images.
+            gt (Tensor): Ground-truth image. Default: None.
+            test_mode (bool): Whether in test mode or not. Default: False.
+            kwargs (dict): Other arguments.
+        """
+
+        if test_mode:
+            return self.forward_test(lq, gt=gt, **kwargs)
+
+        return self.forward_dummy(lq, **kwargs)
+
+    def train_step(self, data_batch, optimizer):
+        """Train step.
+
+        Args:
+            data_batch (dict): A batch of data, which requires
+                'lq', 'gt', 'lq_up', 'ref', 'ref_downup'
+            optimizer (obj): Optimizer.
+
+        Returns:
+            dict: Returned output, which includes:
+                log_vars, num_samples, results (lq, gt, ref and output).
+
+        """
+        # data
+        lq = data_batch['lq']
+        lq_up = data_batch['lq_up']
+        gt = data_batch['gt']
+        ref = data_batch['ref']
+        ref_downup = data_batch['ref_downup']
+
+        # generate
+        pred = self.forward_dummy(lq, lq_up, ref, ref_downup)
+
+        # loss
+        losses = dict()
+
+        losses['loss_pix'] = self.pixel_loss(pred, gt)
+
+        # parse loss
+        loss, log_vars = self.parse_losses(losses)
+
+        # optimize
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        log_vars.pop('loss')  # remove the unnecessary 'loss'
+        outputs = dict(
+            log_vars=log_vars,
+            num_samples=len(gt.data),
+            results=dict(
+                lq=lq.cpu(), gt=gt.cpu(), ref=ref.cpu(), output=pred.cpu()))
+
+        return outputs
+
+    def forward_test(self,
+                     lq,
+                     lq_up,
+                     ref,
+                     ref_downup,
+                     gt=None,
+                     meta=None,
+                     save_image=False,
+                     save_path=None,
+                     iteration=None):
+        """Testing forward function.
+
+        Args:
+            lq (Tensor): LQ image
+            gt (Tensor): GT image
+            lq_up (Tensor): Upsampled LQ image
+            ref (Tensor): Reference image
+            ref_downup (Tensor): Image generated by sequentially applying
+                bicubic down-sampling and up-sampling on reference image
+            meta (list[dict]): Meta data, such as path of GT file.
+                Default: None.
+            save_image (bool): Whether to save image. Default: False.
+            save_path (str): Path to save image. Default: None.
+            iteration (int): Iteration for the saving image name.
+                Default: None.
+
+        Returns:
+            dict: Output results, which contain either key(s)
+                1. 'eval_result'.
+                2. 'lq', 'output'.
+                3. 'lq', 'output', 'gt'.
+        """
+
+        # generator
+        with torch.no_grad():
+            pred = self.forward_dummy(
+                lq=lq, lq_up=lq_up, ref=ref, ref_downup=ref_downup)
+
+        pred = (pred + 1.) / 2.  # maps [-1, 1] to [0, 1] (assumed range)
+        if gt is not None:
+            gt = (gt + 1.) / 2.
+
+        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
+            assert gt is not None, (
+                'evaluation with metrics must have gt images.')
+            results = dict(eval_result=self.evaluate(pred, gt))
+        else:
+            results = dict(lq=lq.cpu(), output=pred.cpu())
+            if gt is not None:
+                results['gt'] = gt.cpu()
+
+        # save image
+        if save_image:
+            if 'gt_path' in meta[0]:
+                the_path = meta[0]['gt_path']
+            else:
+                the_path = meta[0]['lq_path']
+            folder_name = osp.splitext(osp.basename(the_path))[0]
+            if isinstance(iteration, numbers.Number):
+                save_path = osp.join(save_path, folder_name,
+                                     f'{folder_name}-{iteration + 1:06d}.png')
+            elif iteration is None:
+                save_path = osp.join(save_path, f'{folder_name}.png')
+            else:
+                raise ValueError('iteration should be number or None, '
+                                 f'but got {type(iteration)}')
+            mmcv.imwrite(tensor2img(pred), save_path)
+
+        return results
+
+    def init_weights(self, pretrained=None, strict=True):
+        """Init weights for models.
+
+        Args:
+            pretrained (str, optional): Path for pretrained weights. If given
+                None, pretrained weights will not be loaded. Defaults to None.
+            strict (bool, optional): Whether to strictly load the
+                pretrained model. Defaults to True.
+        """
+        if isinstance(pretrained, str):
+            if self.generator:
+                self.generator.init_weights(pretrained, strict)
+            if self.extractor:
+                self.extractor.init_weights(pretrained, strict)
+            if self.transformer:
+                self.transformer.init_weights(pretrained, strict)
+        elif pretrained is not None:
+            raise TypeError('"pretrained" must be a str or None. '
+                            f'But received {type(pretrained)}.')
diff --git a/tests/test_ttsr.py b/tests/test_ttsr.py
index 1e7d99af41..28e19a3f26 100644
--- a/tests/test_ttsr.py
+++ b/tests/test_ttsr.py
@@ -1,6 +1,9 @@
+import numpy as np
 import torch
+from mmcv.runner import obj_from_dict
+from mmcv.utils.config import Config
 
-from mmedit.models import build_backbone
+from mmedit.models import build_backbone, build_model
 from mmedit.models.backbones.sr_backbones.ttsr_net import (CSFI2, CSFI3, SFE,
                                                            MergeFeatures)
 
@@ -56,3 +59,88 @@ def test_ttsr_net():
     outputs = ttsr(inputs, s, t_level3, t_level2, t_level1)
 
     assert outputs.shape == (2, 3, 96, 96)
+
+
+def test_ttsr():
+
+    model_cfg = dict(
+        type='TTSR',
+        generator=dict(
+            type='TTSRNet',
+            in_channels=3,
+            out_channels=3,
+            mid_channels=64,
+            num_blocks=(16, 16, 8, 4)),
+        extractor=dict(type='LTE'),
+        transformer=dict(type='SearchTransformer'),
+        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))
+
+    scale = 4
+    train_cfg = None
+    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))
+
+    # build restorer
+    restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
+
+    # test attributes
+    assert restorer.__class__.__name__ == 'TTSR'
+
+    # prepare data
+    inputs = torch.rand(1, 3, 16, 16)
+    targets = torch.rand(1, 3, 64, 64)
+    ref = torch.rand(1, 3, 64, 64)
+    data_batch = {
+        'lq': inputs,
+        'gt': targets,
+        'ref': ref,
+        'lq_up': ref,
+        'ref_downup': ref
+    }
+
+    # prepare optimizer
+    optim_cfg = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
+    optimizer = obj_from_dict(optim_cfg, torch.optim,
+                              dict(params=restorer.parameters()))
+
+    # test train_step (cpu)
+    outputs = restorer.train_step(data_batch, optimizer)
+    assert isinstance(outputs, dict)
+    assert isinstance(outputs['log_vars'], dict)
+    assert isinstance(outputs['log_vars']['loss_pix'], float)
+    assert outputs['num_samples'] == 1
+    assert outputs['results']['lq'].shape == data_batch['lq'].shape
+    assert outputs['results']['gt'].shape == data_batch['gt'].shape
+    assert torch.is_tensor(outputs['results']['output'])
+    assert outputs['results']['output'].size() == (1, 3, 64, 64)
+
+    # test train_step and val_step (gpu)
+    if torch.cuda.is_available():
+        restorer = restorer.cuda()
+        data_batch = {
+            'lq': inputs.cuda(),
+            'gt': targets.cuda(),
+            'ref': ref.cuda(),
+            'lq_up': ref.cuda(),
+            'ref_downup': ref.cuda()
+        }
+
+        # train_step
+        optimizer = obj_from_dict(optim_cfg, torch.optim,
+                                  dict(params=restorer.parameters()))
+        outputs = restorer.train_step(data_batch, optimizer)
+        assert isinstance(outputs, dict)
+        assert isinstance(outputs['log_vars'], dict)
+        assert isinstance(outputs['log_vars']['loss_pix'], float)
+        assert outputs['num_samples'] == 1
+        assert outputs['results']['lq'].shape == data_batch['lq'].shape
+        assert outputs['results']['gt'].shape == data_batch['gt'].shape
+        assert torch.is_tensor(outputs['results']['output'])
+        assert outputs['results']['output'].size() == (1, 3, 64, 64)
+
+        # val_step
+        result = restorer.val_step(data_batch, meta=[{'gt_path': ''}])
+        assert isinstance(result, dict)
+        assert isinstance(result['eval_result'], dict)
+        assert result['eval_result'].keys() == set({'PSNR', 'SSIM'})
+        assert isinstance(result['eval_result']['PSNR'], np.float64)
+        assert isinstance(result['eval_result']['SSIM'], np.float64)