From 5ad9339b51fcd088ef5d60f0da037b9265cbc2b2 Mon Sep 17 00:00:00 2001
From: liyinshuo
Date: Tue, 8 Jun 2021 11:32:16 +0800
Subject: [PATCH 1/5] [Feature] Add DIC

---
 mmedit/models/restorers/__init__.py |   3 +-
 mmedit/models/restorers/dic.py      | 201 ++++++++++++++++++++++++++++
 2 files changed, 203 insertions(+), 1 deletion(-)
 create mode 100644 mmedit/models/restorers/dic.py

diff --git a/mmedit/models/restorers/__init__.py b/mmedit/models/restorers/__init__.py
index b0d899e35e..1f5e41eb89 100644
--- a/mmedit/models/restorers/__init__.py
+++ b/mmedit/models/restorers/__init__.py
@@ -1,5 +1,6 @@
 from .basic_restorer import BasicRestorer
 from .basicvsr import BasicVSR
+from .dic import DIC
 from .edvr import EDVR
 from .esrgan import ESRGAN
 from .glean import GLEAN
@@ -10,5 +11,5 @@

 __all__ = [
     'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR',
-    'GLEAN', 'TDAN'
+    'GLEAN', 'TDAN', 'DIC'
 ]
diff --git a/mmedit/models/restorers/dic.py b/mmedit/models/restorers/dic.py
new file mode 100644
index 0000000000..b2e4f592d4
--- /dev/null
+++ b/mmedit/models/restorers/dic.py
@@ -0,0 +1,201 @@
+import numbers
+import os.path as osp
+from collections import OrderedDict
+
+import mmcv
+import torch
+
+from mmedit.core import tensor2img
+from mmedit.models.common import ImgNormalize
+from ..builder import build_backbone, build_loss
+from ..registry import MODELS
+from .basic_restorer import BasicRestorer
+
+
+@MODELS.register_module()
+class DIC(BasicRestorer):
+    """DIC model for Face Super-Resolution.
+
+    Paper: Deep Face Super-Resolution with Iterative Collaboration between
+    Attentive Recovery and Landmark Estimation.
+
+    Args:
+        generator (dict): Config for the generator.
+        pixel_loss (dict): Config for the pixel loss.
+        align_loss (dict): Config for the align loss.
+        train_cfg (dict): Config for training. Default: None.
+        test_cfg (dict): Config for testing. Default: None.
+        pretrained (str): Path for pretrained model. Default: None.
+    """
+
+    def __init__(self,
+                 generator,
+                 pixel_loss=None,
+                 align_loss=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(BasicRestorer, self).__init__()
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        # model
+        self.generator = build_backbone(generator)
+        self.img_normalize = ImgNormalize(
+            pixel_range=1,
+            img_mean=(129.795, 108.12, 96.39),
+            img_std=(255, 255, 255))
+        self.img_denormalize = ImgNormalize(
+            pixel_range=1,
+            img_mean=(0.509, 0.424, 0.378),
+            img_std=(1., 1., 1.),
+            sign=1)
+
+        # loss
+        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
+        self.align_loss = build_loss(align_loss) if align_loss else None
+
+        # pretrained
+        if pretrained:
+            self.init_weights(pretrained)
+
+    def forward(self, lq, gt=None, test_mode=False, **kwargs):
+        """Forward function.
+
+        Args:
+            lq (Tensor): Input lq images.
+            gt (Tensor): Ground-truth image. Default: None.
+            test_mode (bool): Whether in test mode or not. Default: False.
+            kwargs (dict): Other arguments.
+        """
+
+        if test_mode:
+            return self.forward_test(lq, gt=gt, **kwargs)
+
+        return self.generator.forward(lq)
+
+    def train_step(self, data_batch, optimizer):
+        """Train step.
+
+        Args:
+            data_batch (dict): A batch of data, which requires
+                'lq', 'gt' and 'heatmap'.
+            optimizer (obj): Optimizer.
+
+        Returns:
+            dict: Returned output, which includes:
+                log_vars, num_samples, results (lq, gt and pred).
+
+        """
+        # data
+        lq = data_batch['lq']
+        gt = data_batch['gt']
+        gt_heatmap = data_batch['heatmap']
+
+        # generate
+        sr_list, heatmap_list = self.generator.forward(lq)
+
+        # loss
+        losses = OrderedDict()
+
+        loss_pix = 0.0
+        loss_align = 0.0
+        for step, (sr, heatmap) in enumerate(zip(sr_list, heatmap_list)):
+            losses[f'loss_pixel_v{step}'] = self.pixel_loss(sr, gt)
+            loss_pix += losses[f'loss_pixel_v{step}']
+            losses[f'loss_align_v{step}'] = self.align_loss(
+                heatmap, gt_heatmap)
+            loss_align += losses[f'loss_align_v{step}']
+
+        # parse loss
+        loss, log_vars = self.parse_losses(losses)
+
+        # optimize
+        optimizer['generator'].zero_grad()
+        loss.backward()
+        optimizer['generator'].step()
+
+        log_vars.pop('loss')  # remove the unnecessary 'loss'
+        outputs = dict(
+            log_vars=log_vars,
+            num_samples=len(gt.data),
+            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=sr_list[-1].cpu()))
+
+        return outputs
+
+    def forward_test(self,
+                     lq,
+                     gt=None,
+                     meta=None,
+                     save_image=False,
+                     save_path=None,
+                     iteration=None):
+        """Testing forward function.
+
+        Args:
+            lq (Tensor): LQ image.
+            gt (Tensor): GT image.
+            meta (list[dict]): Meta data, such as path of GT file.
+                Default: None.
+            save_image (bool): Whether to save image. Default: False.
+            save_path (str): Path to save image. Default: None.
+            iteration (int): Iteration for the saving image name.
+                Default: None.
+
+        Returns:
+            dict: Output results, which contain either key(s)
+                1. 'eval_result'.
+                2. 'lq', 'pred'.
+                3. 'lq', 'pred', 'gt'.
+        """
+
+        # generator
+        with torch.no_grad():
+            sr_list, _ = self.generator.forward(lq)
+            pred = sr_list[3]
+            pred = self.img_denormalize(pred)
+
+        if gt is not None:
+            gt = self.img_denormalize(gt)
+
+        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
+            assert gt is not None, (
+                'evaluation with metrics must have gt images.')
+            results = dict(eval_result=self.evaluate(pred, gt))
+        else:
+            results = dict(lq=lq.cpu(), output=pred.cpu())
+            if gt is not None:
+                results['gt'] = gt.cpu()
+
+        # save image
+        if save_image:
+            if 'gt_path' in meta[0]:
+                the_path = meta[0]['gt_path']
+            else:
+                the_path = meta[0]['lq_path']
+            folder_name = osp.splitext(osp.basename(the_path))[0]
+            if isinstance(iteration, numbers.Number):
+                save_path = osp.join(save_path, folder_name,
+                                     f'{folder_name}-{iteration + 1:06d}.png')
+            elif iteration is None:
+                save_path = osp.join(save_path, f'{folder_name}.png')
+            else:
+                raise ValueError('iteration should be number or None, '
+                                 f'but got {type(iteration)}')
+            mmcv.imwrite(tensor2img(pred), save_path)
+
+        return results
+
+    def val_step(self, data_batch, **kwargs):
+        """Validation step.
+
+        Args:
+            data_batch (dict): A batch of data.
+            kwargs (dict): Other arguments for ``val_step``.
+
+        Returns:
+            dict: Returned output.
+        """
+        output = self.forward_test(**data_batch, **kwargs)
+        return output

From 1d90cc622615b8ece02211ee0976b0a7571a688f Mon Sep 17 00:00:00 2001
From: liyinshuo
Date: Sat, 12 Jun 2021 10:37:42 +0800
Subject: [PATCH 2/5] Add test

---
 mmedit/models/restorers/dic.py | 14 ++---
 tests/test_dic_model.py        | 95 ++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_dic_model.py

diff --git a/mmedit/models/restorers/dic.py b/mmedit/models/restorers/dic.py
index b2e4f592d4..2d8a56b1af 100644
--- a/mmedit/models/restorers/dic.py
+++ b/mmedit/models/restorers/dic.py
@@ -30,8 +30,8 @@ class DIC(BasicRestorer):

     def __init__(self,
                  generator,
-                 pixel_loss=None,
-                 align_loss=None,
+                 pixel_loss,
+                 align_loss,
                  train_cfg=None,
                  test_cfg=None,
                  pretrained=None):
@@ -42,10 +42,6 @@ def __init__(self,

         # model
         self.generator = build_backbone(generator)
-        self.img_normalize = ImgNormalize(
-            pixel_range=1,
-            img_mean=(129.795, 108.12, 96.39),
-            img_std=(255, 255, 255))
         self.img_denormalize = ImgNormalize(
             pixel_range=1,
             img_mean=(0.509, 0.424, 0.378),
@@ -53,8 +49,8 @@ def __init__(self,
             sign=1)

         # loss
-        self.pixel_loss = build_loss(pixel_loss) if pixel_loss else None
-        self.align_loss = build_loss(align_loss) if align_loss else None
+        self.pixel_loss = build_loss(pixel_loss)
+        self.align_loss = build_loss(align_loss)

         # pretrained
         if pretrained:
@@ -153,7 +149,7 @@ def forward_test(self,
         # generator
         with torch.no_grad():
             sr_list, _ = self.generator.forward(lq)
-            pred = sr_list[3]
+            pred = sr_list[-1]
             pred = self.img_denormalize(pred)

         if gt is not None:
diff --git a/tests/test_dic_model.py b/tests/test_dic_model.py
new file mode 100644
index 0000000000..639db2bb68
--- /dev/null
+++ b/tests/test_dic_model.py
@@ -0,0 +1,95 @@
+import numpy as np
+import pytest
+import torch
+from mmcv.runner import obj_from_dict
+from mmcv.utils.config import Config
+
+from mmedit.models.builder import build_model
+
+
+def test_dic_midel():
+
+    model_cfg = dict(
+        type='DIC',
+        generator=dict(
+            type='DICNet',
+            in_channels=3,
+            out_channels=3,
+            mid_channels=48,
+            num_blocks=6,
+            hg_mid_channels=256,
+            hg_num_keypoints=68,
+            num_steps=4,
+            upscale_factor=8,
+            detach_attention=False),
+        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
+        align_loss=dict(type='MSELoss', loss_weight=0.1, reduction='mean'))
+
+    scale = 8
+    train_cfg = None
+    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))
+
+    # build restorer
+    restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)
+
+    # test attributes
+    assert restorer.__class__.__name__ == 'DIC'
+
+    # prepare data
+    inputs = torch.rand(1, 3, 16, 16)
+    targets = torch.rand(1, 3, 128, 128)
+    heatmap = torch.rand(1, 68, 32, 32)
+    data_batch = {'lq': inputs, 'gt': targets, 'heatmap': heatmap}
+
+    # prepare optimizer
+    optim_cfg = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
+    optimizer = dict(
+        generator=obj_from_dict(optim_cfg, torch.optim,
+                                dict(params=restorer.parameters())))
+
+    # test train_step and forward_test (cpu)
+    outputs = restorer.train_step(data_batch, optimizer)
+    assert isinstance(outputs, dict)
+    assert isinstance(outputs['log_vars'], dict)
+    assert isinstance(outputs['log_vars']['loss_pixel_v3'], float)
+    assert outputs['num_samples'] == 1
+    assert outputs['results']['lq'].shape == data_batch['lq'].shape
+    assert outputs['results']['gt'].shape == data_batch['gt'].shape
+    assert torch.is_tensor(outputs['results']['output'])
+    assert outputs['results']['output'].size() == (1, 3, 128, 128)
+
+    # test train_step and forward_test (gpu)
+    if torch.cuda.is_available():
+        restorer = restorer.cuda()
+        data_batch = {
+            'lq': inputs.cuda(),
+            'gt': targets.cuda(),
+            'heatmap': heatmap.cuda()
+        }
+
+        # train_step
+        optimizer = dict(
+            generator=obj_from_dict(optim_cfg, torch.optim,
+                                    dict(params=restorer.parameters())))
+        outputs = restorer.train_step(data_batch, optimizer)
+        assert isinstance(outputs, dict)
+        assert isinstance(outputs['log_vars'], dict)
+        assert isinstance(outputs['log_vars']['loss_pixel_v3'], float)
+        assert outputs['num_samples'] == 1
+        assert outputs['results']['lq'].shape == data_batch['lq'].shape
+        assert outputs['results']['gt'].shape == data_batch['gt'].shape
+        assert torch.is_tensor(outputs['results']['output'])
+        assert outputs['results']['output'].size() == (1, 3, 128, 128)
+
+        # val_step
+        data_batch.pop('heatmap')
+        result = restorer.val_step(data_batch, meta=[{'gt_path': ''}])
+        assert isinstance(result, dict)
+        assert isinstance(result['eval_result'], dict)
+        assert result['eval_result'].keys() == set({'PSNR', 'SSIM'})
+        assert isinstance(result['eval_result']['PSNR'], np.float64)
+        assert isinstance(result['eval_result']['SSIM'], np.float64)
+
+    with pytest.raises(AssertionError):
+        # evaluation with metrics must have gt images
+        restorer(lq=inputs.cuda(), test_mode=True)

From 6404db22bb17e1dbd7e3e331b3463813a8587dc7 Mon Sep 17 00:00:00 2001
From: liyinshuo
Date: Sat, 12 Jun 2021 17:27:10 +0800
Subject: [PATCH 3/5] Fix

---
 mmedit/models/restorers/dic.py | 15 +--------------
 tests/test_dic_model.py        |  6 +++++-
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/mmedit/models/restorers/dic.py b/mmedit/models/restorers/dic.py
index 2d8a56b1af..8afc3ef178 100644
--- a/mmedit/models/restorers/dic.py
+++ b/mmedit/models/restorers/dic.py
@@ -90,7 +90,7 @@ def train_step(self, data_batch, optimizer):
         gt_heatmap = data_batch['heatmap']

         # generate
-        sr_list, heatmap_list = self.generator.forward(lq)
+        sr_list, heatmap_list = self(**data_batch, test_mode=False)

         # loss
         losses = OrderedDict()
@@ -182,16 +182,3 @@ def forward_test(self,
             mmcv.imwrite(tensor2img(pred), save_path)

         return results
-
-    def val_step(self, data_batch, **kwargs):
-        """Validation step.
-
-        Args:
-            data_batch (dict): A batch of data.
-            kwargs (dict): Other arguments for ``val_step``.
-
-        Returns:
-            dict: Returned output.
-        """
-        output = self.forward_test(**data_batch, **kwargs)
-        return output
diff --git a/tests/test_dic_model.py b/tests/test_dic_model.py
index 639db2bb68..a73c2fcc68 100644
--- a/tests/test_dic_model.py
+++ b/tests/test_dic_model.py
@@ -7,7 +7,7 @@
 from mmedit.models.builder import build_model


-def test_dic_midel():
+def test_dic_model():

     model_cfg = dict(
         type='DIC',
@@ -93,3 +93,7 @@ def test_dic_midel():
     with pytest.raises(AssertionError):
         # evaluation with metrics must have gt images
         restorer(lq=inputs.cuda(), test_mode=True)
+
+
+if __name__ == '__main__':
+    test_dic_model()

From 0d1f196a985215e389032899afb0dc72eb50c64d Mon Sep 17 00:00:00 2001
From: liyinshuo
Date: Tue, 15 Jun 2021 13:49:47 +0800
Subject: [PATCH 4/5] Fix

---
 mmedit/models/restorers/dic.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mmedit/models/restorers/dic.py b/mmedit/models/restorers/dic.py
index 8afc3ef178..3421ec6bd3 100644
--- a/mmedit/models/restorers/dic.py
+++ b/mmedit/models/restorers/dic.py
@@ -53,8 +53,7 @@ def __init__(self,
         self.align_loss = build_loss(align_loss)

         # pretrained
-        if pretrained:
-            self.init_weights(pretrained)
+        self.init_weights(pretrained)

     def forward(self, lq, gt=None, test_mode=False, **kwargs):
         """Forward function.
@@ -167,10 +166,10 @@ def forward_test(self,
         # save image
         if save_image:
             if 'gt_path' in meta[0]:
-                the_path = meta[0]['gt_path']
+                pred_path = meta[0]['gt_path']
             else:
-                the_path = meta[0]['lq_path']
-            folder_name = osp.splitext(osp.basename(the_path))[0]
+                pred_path = meta[0]['lq_path']
+            folder_name = osp.splitext(osp.basename(pred_path))[0]
             if isinstance(iteration, numbers.Number):
                 save_path = osp.join(save_path, folder_name,
                                      f'{folder_name}-{iteration + 1:06d}.png')

From b091666fe1d63b6c9d876cd8eae386a8f365ad10 Mon Sep 17 00:00:00 2001
From: liyinshuo
Date: Tue, 15 Jun 2021 15:07:42 +0800
Subject: [PATCH 5/5] Update

---
 tests/test_dic_model.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/test_dic_model.py b/tests/test_dic_model.py
index a73c2fcc68..b957f61cb5 100644
--- a/tests/test_dic_model.py
+++ b/tests/test_dic_model.py
@@ -93,7 +93,3 @@ def test_dic_model():
     with pytest.raises(AssertionError):
         # evaluation with metrics must have gt images
         restorer(lq=inputs.cuda(), test_mode=True)
-
-
-if __name__ == '__main__':
-    test_dic_model()
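
The patches above add the DIC restorer and its unit test, but no standalone example of calling the model. Below is a minimal usage sketch assembled from the config used in tests/test_dic_model.py; it assumes the DICNet backbone and the L1/MSE losses are registered exactly as in this series, and it is an illustration rather than part of the patches.

    import torch
    from mmcv.utils.config import Config

    from mmedit.models.builder import build_model

    # Same generator and loss settings as the unit test added in PATCH 2/5.
    model_cfg = dict(
        type='DIC',
        generator=dict(
            type='DICNet',
            in_channels=3,
            out_channels=3,
            mid_channels=48,
            num_blocks=6,
            hg_mid_channels=256,
            hg_num_keypoints=68,
            num_steps=4,
            upscale_factor=8,
            detach_attention=False),
        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
        align_loss=dict(type='MSELoss', loss_weight=0.1, reduction='mean'))

    # crop_border matches the x8 upscale factor, as in the test.
    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=8))
    restorer = build_model(model_cfg, train_cfg=None, test_cfg=test_cfg)

    # 16x16 LR face -> 128x128 SR face. Passing gt in test mode triggers metric
    # evaluation in forward_test, so the result carries 'eval_result' with PSNR/SSIM.
    lq = torch.rand(1, 3, 16, 16)
    gt = torch.rand(1, 3, 128, 128)
    results = restorer(lq=lq, gt=gt, test_mode=True)
    print(results['eval_result'])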