From f246dde738407c6be32ac250173f9caf9f4408cd Mon Sep 17 00:00:00 2001
From: liyinshuo
Date: Thu, 20 May 2021 11:12:31 +0800
Subject: [PATCH] [Feature] Add TTSR.

---
 mmedit/models/restorers/__init__.py |   5 +-
 mmedit/models/restorers/ttsr.py     | 257 ++++++++++++++++++++++++++++
 tests/test_ttsr.py                  |   5 +
 3 files changed, 266 insertions(+), 1 deletion(-)
 create mode 100644 mmedit/models/restorers/ttsr.py

diff --git a/mmedit/models/restorers/__init__.py b/mmedit/models/restorers/__init__.py
index 40eccbc07d..893cbaab09 100644
--- a/mmedit/models/restorers/__init__.py
+++ b/mmedit/models/restorers/__init__.py
@@ -4,5 +4,8 @@
 from .esrgan import ESRGAN
 from .liif import LIIF
 from .srgan import SRGAN
+from .ttsr import TTSR
 
-__all__ = ['BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR']
+__all__ = [
+    'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR'
+]
diff --git a/mmedit/models/restorers/ttsr.py b/mmedit/models/restorers/ttsr.py
new file mode 100644
index 0000000000..ed49e6e39f
--- /dev/null
+++ b/mmedit/models/restorers/ttsr.py
@@ -0,0 +1,257 @@
+import numbers
+import os.path as osp
+
+import mmcv
+import torch
+
+from mmedit.core import tensor2img
+from ..builder import build_backbone, build_component, build_loss
+from ..registry import MODELS
+from .basic_restorer import BasicRestorer
+
+
+@MODELS.register_module()
+class TTSR(BasicRestorer):
+    """TTSR model for Reference-based Image Super-Resolution.
+
+    Paper: Learning Texture Transformer Network for Image Super-Resolution.
+
+    Args:
+        generator (dict): Config for the generator.
+        extractor (dict): Config for the extractor.
+        transformer (dict): Config for the transformer.
+        pixel_loss (dict): Config for the pixel loss. Default: None.
+        train_cfg (dict): Config for train. Default: None.
+        test_cfg (dict): Config for testing. Default: None.
+        pretrained (str): Path for pretrained model. Default: None.
+    """
+
+    def __init__(self,
+                 generator,
+                 extractor,
+                 transformer,
+                 pixel_loss=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(BasicRestorer, self).__init__()  # skips BasicRestorer.__init__
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        # model
+        self.generator = build_backbone(generator)
+        self.transformer = build_component(transformer)
+        self.extractor = build_component(extractor)
+
+        # loss
+        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
+
+        # pretrained
+        if pretrained:
+            self.init_weights(pretrained)
+
+    def forward_dummy(self, lq, lq_pad, ref, ref_pad, only_pred=True):
+        """Forward of networks.
+
+        Args:
+            lq (Tensor): LQ image
+            lq_pad (Tensor): Upsampled LQ image
+            ref (Tensor): Reference image
+            ref_pad (Tensor): Image generated by sequentially applying
+                bicubic down-sampling and up-sampling on reference image
+            only_pred (bool): Only return pred or not. Default: True
+
+        Returns:
+            pred (Tensor): Predicted super-resolution results (n, 3, 4h, 4w)
+            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
+            t_level3 (Tensor): Transferred HR texture T in level3.
+                (n, 4c, h, w)
+            t_level2 (Tensor): Transferred HR texture T in level2.
+                (n, 2c, 2h, 2w)
+            t_level1 (Tensor): Transferred HR texture T in level1.
+                (n, c, 4h, 4w)
+        """
+
+        _, _, lq_pad_level3 = self.extractor((lq_pad.detach() + 1.) / 2.)
+        _, _, ref_pad_level3 = self.extractor((ref_pad.detach() + 1.) / 2.)
+        ref_level1, ref_level2, ref_level3 = self.extractor(
+            (ref.detach() + 1.) / 2.)
+
+        s, t_level3, t_level2, t_level1 = self.transformer(
+            lq_pad_level3, ref_pad_level3, ref_level1, ref_level2, ref_level3)
+
+        pred = self.generator(lq, s, t_level3, t_level2, t_level1)
+
+        if only_pred:
+            return pred
+        return pred, s, t_level3, t_level2, t_level1
+
+    def forward(self, lq, gt=None, test_mode=False, **kwargs):
+        """Forward function.
+
+        Args:
+            lq (Tensor): Input lq images.
+            gt (Tensor): Ground-truth image. Default: None.
+            test_mode (bool): Whether in test mode or not. Default: False.
+            kwargs (dict): Other arguments.
+        """
+
+        if test_mode:
+            return self.forward_test(lq, gt=gt, **kwargs)
+
+        return self.forward_dummy(lq, **kwargs)
+
+    def train_step(self, data_batch, optimizer):
+        """Train step.
+
+        Args:
+            data_batch (dict): A batch of data, which requires
+                'lq', 'gt', 'lq_pad', 'ref', 'ref_pad'
+            optimizer (obj): Optimizer.
+
+        Returns:
+            dict: Returned output, which includes:
+                log_vars, num_samples, results (lq, gt, ref and output).
+
+        """
+        # data
+        lq = data_batch['lq']
+        lq_pad = data_batch['lq_pad']
+        gt = data_batch['gt']
+        ref = data_batch['ref']
+        ref_pad = data_batch['ref_pad']
+
+        # generate
+        pred = self.forward_dummy(lq, lq_pad, ref, ref_pad)
+
+        # loss
+        losses = dict()
+        log_vars = dict()
+
+        losses['loss_pix'] = self.pixel_loss(pred, gt)
+
+        # parse loss
+        loss, log_vars = self.parse_losses(losses)
+
+        # optimize
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        log_vars.pop('loss')  # remove the unnecessary 'loss'
+        outputs = dict(
+            log_vars=log_vars,
+            num_samples=len(gt.data),
+            results=dict(
+                lq=lq.cpu(), gt=gt.cpu(), ref=ref.cpu(), output=pred.cpu()))
+
+        pred = None
+        loss = None
+
+        return outputs
+
+    def forward_test(self,
+                     lq,
+                     lq_pad,
+                     ref,
+                     ref_pad,
+                     gt=None,
+                     meta=None,
+                     save_image=False,
+                     save_path=None,
+                     iteration=None):
+        """Testing forward function.
+
+        Args:
+            lq (Tensor): LQ image
+            gt (Tensor): GT image
+            lq_pad (Tensor): Upsampled LQ image
+            ref (Tensor): Reference image
+            ref_pad (Tensor): Image generated by sequentially applying
+                bicubic down-sampling and up-sampling on reference image
+            meta (list[dict]): Meta data, such as path of GT file.
+                Required when ``save_image`` is True. Default: None.
+            save_image (bool): Whether to save image. Default: False.
+            save_path (str): Path to save image. Default: None.
+            iteration (int): Iteration for the saving image name.
+                Default: None.
+
+        Returns:
+            dict: Output results, which contain either key(s)
+                1. 'eval_result'.
+                2. 'lq', 'output'.
+                3. 'lq', 'output', 'gt'.
+        """
+
+        # generator
+        with torch.no_grad():
+            pred = self.forward_dummy(
+                lq=lq, lq_pad=lq_pad, ref=ref, ref_pad=ref_pad)
+
+        pred = (pred + 1.) / 2.
+        if gt is not None:
+            gt = (gt + 1.) / 2.
+
+        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
+            assert gt is not None, (
+                'evaluation with metrics must have gt images.')
+            results = dict(eval_result=self.evaluate(pred, gt))
+        else:
+            results = dict(lq=lq.cpu(), output=pred.cpu())
+            if gt is not None:
+                results['gt'] = gt.cpu()
+
+        # save image
+        if save_image:
+            if 'gt_path' in meta[0]:
+                the_path = meta[0]['gt_path']
+            else:
+                the_path = meta[0]['lq_path']
+            folder_name = osp.splitext(osp.basename(the_path))[0]
+            if isinstance(iteration, numbers.Number):
+                save_path = osp.join(save_path, folder_name,
+                                     f'{folder_name}-{iteration + 1:06d}.png')
+            elif iteration is None:
+                save_path = osp.join(save_path, f'{folder_name}.png')
+            else:
+                raise ValueError('iteration should be number or None, '
+                                 f'but got {type(iteration)}')
+            mmcv.imwrite(tensor2img(pred), save_path)
+
+        return results
+
+    def val_step(self, data_batch, **kwargs):
+        """Validation step.
+
+        Args:
+            data_batch (dict): A batch of data.
+            kwargs (dict): Other arguments for ``val_step``.
+
+        Returns:
+            dict: Returned output.
+        """
+        output = self.forward_test(**data_batch, **kwargs)
+        return output
+
+    def init_weights(self, pretrained=None, strict=True):
+        """Init weights for models.
+
+        Args:
+            pretrained (str, optional): Path for pretrained weights. If given
+                None, pretrained weights will not be loaded. Defaults to None.
+            strict (bool, optional): Whether to strictly load the pretrained
+                model. Defaults to True.
+        """
+        if isinstance(pretrained, str):
+            if self.generator:
+                self.generator.init_weights(pretrained, strict)
+            if self.extractor:
+                self.extractor.init_weights(pretrained, strict)
+            if self.transformer:
+                self.transformer.init_weights(pretrained, strict)
+        elif pretrained is None:
+            pass  # use default initialization
+        else:
+            raise TypeError('"pretrained" must be a str or None. '
+                            f'But received {type(pretrained)}.')
diff --git a/tests/test_ttsr.py b/tests/test_ttsr.py
index 1e7d99af41..02333ddfc0 100644
--- a/tests/test_ttsr.py
+++ b/tests/test_ttsr.py
@@ -56,3 +56,8 @@ def test_ttsr_net():
     outputs = ttsr(inputs, s, t_level3, t_level2, t_level1)
 
     assert outputs.shape == (2, 3, 96, 96)
+
+
+def test_ttsr():
+    # TODO wait for ttsr_net, lte, and search_transformer.
+    pass