[Feature] Add DIC #363

Merged: 5 commits, Jun 15, 2021
3 changes: 2 additions & 1 deletion mmedit/models/restorers/__init__.py
@@ -1,5 +1,6 @@
from .basic_restorer import BasicRestorer
from .basicvsr import BasicVSR
from .dic import DIC
from .edvr import EDVR
from .esrgan import ESRGAN
from .glean import GLEAN
@@ -10,5 +11,5 @@

__all__ = [
'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR',
'GLEAN', 'TDAN'
'GLEAN', 'TDAN', 'DIC'
]
197 changes: 197 additions & 0 deletions mmedit/models/restorers/dic.py
@@ -0,0 +1,197 @@
import numbers
import os.path as osp
from collections import OrderedDict

import mmcv
import torch

from mmedit.core import tensor2img
from mmedit.models.common import ImgNormalize
from ..builder import build_backbone, build_loss
from ..registry import MODELS
from .basic_restorer import BasicRestorer


@MODELS.register_module()
class DIC(BasicRestorer):
"""DIC model for Face Super-Resolution.

Paper: Deep Face Super-Resolution with Iterative Collaboration between
Attentive Recovery and Landmark Estimation.

Args:
generator (dict): Config for the generator.
pixel_loss (dict): Config for the pixel loss.
align_loss (dict): Config for the align loss.
train_cfg (dict): Config for train. Default: None.
test_cfg (dict): Config for testing. Default: None.
pretrained (str): Path for pretrained model. Default: None.
"""

def __init__(self,
generator,
pixel_loss,
align_loss,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(BasicRestorer, self).__init__()

self.train_cfg = train_cfg
self.test_cfg = test_cfg

# model
self.generator = build_backbone(generator)
self.img_denormalize = ImgNormalize(
Contributor: Where is it used?

Contributor: Out of curiosity, where is the normalize that corresponds to this denormalize?

Collaborator (Author): The normalize is in the pipeline (called in the config file, similar to the other methods in this repo).

Contributor: How about the first question?

Collaborator (Author):
Q: Where is it used?
A: It was deleted by mistake in the previous version (f0b85d2). This problem has been fixed.

Contributor: Can this de-normalization be configurable? Just asking... We can leave the better normalize-denormalize story to future work. Currently they are quite ad-hoc.

Collaborator (Author): post-pipeline?

Contributor: Maybe...
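For reference, the pipeline-side normalization mentioned above would typically be configured roughly as sketched below. This is a hypothetical excerpt, assuming mmedit's RescaleToZeroOne and Normalize pipeline transforms and the same mean/std that ImgNormalize reverses; the exact entries in the DIC config may differ.

train_pipeline = [
    # ... loading and augmentation steps ...
    # Rescale pixel values to [0, 1] before mean subtraction.
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
    # Subtract the channel means; ImgNormalize(..., sign=1) in the model
    # adds the same means back to the network output.
    dict(
        type='Normalize',
        keys=['lq', 'gt'],
        mean=[0.509, 0.424, 0.378],
        std=[1., 1., 1.],
        to_rgb=False),
    # ... formatting and collection steps ...
]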

pixel_range=1,
img_mean=(0.509, 0.424, 0.378),
img_std=(1., 1., 1.),
sign=1)

# loss
self.pixel_loss = build_loss(pixel_loss)
self.align_loss = build_loss(align_loss)

# pretrained
if pretrained:
self.init_weights(pretrained)

def forward(self, lq, gt=None, test_mode=False, **kwargs):
"""Forward function.

Args:
lq (Tensor): Input lq images.
gt (Tensor): Ground-truth image. Default: None.
test_mode (bool): Whether in test mode or not. Default: False.
kwargs (dict): Other arguments.

Returns:
dict or tuple: Output of ``forward_test`` if ``test_mode`` is True,
otherwise the raw output of the generator.
"""

if test_mode:
return self.forward_test(lq, gt=gt, **kwargs)

return self.generator.forward(lq)

def train_step(self, data_batch, optimizer):
"""Train step.

Args:
data_batch (dict): A batch of data, which requires
'lq', 'gt' and 'heatmap'.
optimizer (obj): Optimizer.

Returns:
dict: Returned output, which includes:
log_vars, num_samples, results (lq, gt and output).

"""
# data
lq = data_batch['lq']
gt = data_batch['gt']
gt_heatmap = data_batch['heatmap']

# generate
sr_list, heatmap_list = self.generator.forward(lq)

# loss
losses = OrderedDict()

loss_pix = 0.0
loss_align = 0.0
for step, (sr, heatmap) in enumerate(zip(sr_list, heatmap_list)):
losses[f'loss_pixel_v{step}'] = self.pixel_loss(sr, gt)
loss_pix += losses[f'loss_pixel_v{step}']
losses[f'loss_align_v{step}'] = self.align_loss(
heatmap, gt_heatmap)
loss_align += losses[f'loss_align_v{step}']

# parse loss
loss, log_vars = self.parse_losses(losses)

# optimize
optimizer['generator'].zero_grad()
loss.backward()
optimizer['generator'].step()

log_vars.pop('loss') # remove the unnecessary 'loss'
outputs = dict(
log_vars=log_vars,
num_samples=len(gt.data),
results=dict(lq=lq.cpu(), gt=gt.cpu(), output=sr_list[-1].cpu()))

return outputs

def forward_test(self,
lq,
gt=None,
meta=None,
save_image=False,
save_path=None,
iteration=None):
"""Testing forward function.

Args:
lq (Tensor): LQ image.
gt (Tensor): GT image.
meta (list[dict]): Meta data, such as path of GT file.
Default: None.
save_image (bool): Whether to save image. Default: False.
save_path (str): Path to save image. Default: None.
iteration (int): Iteration for the saving image name.
Default: None.

Returns:
dict: Output results, which contain one of the following key sets:
1. 'eval_result'.
2. 'lq', 'output'.
3. 'lq', 'output', 'gt'.
"""

# generator
with torch.no_grad():
sr_list, _ = self.generator.forward(lq)
pred = sr_list[-1]
pred = self.img_denormalize(pred)

if gt is not None:
gt = self.img_denormalize(gt)

if self.test_cfg is not None and self.test_cfg.get('metrics', None):
assert gt is not None, (
'evaluation with metrics must have gt images.')
results = dict(eval_result=self.evaluate(pred, gt))
else:
results = dict(lq=lq.cpu(), output=pred.cpu())
if gt is not None:
results['gt'] = gt.cpu()

# save image
if save_image:
if 'gt_path' in meta[0]:
the_path = meta[0]['gt_path']
else:
the_path = meta[0]['lq_path']
folder_name = osp.splitext(osp.basename(the_path))[0]
if isinstance(iteration, numbers.Number):
save_path = osp.join(save_path, folder_name,
f'{folder_name}-{iteration + 1:06d}.png')
elif iteration is None:
save_path = osp.join(save_path, f'{folder_name}.png')
else:
raise ValueError('iteration should be number or None, '
f'but got {type(iteration)}')
mmcv.imwrite(tensor2img(pred), save_path)

return results

def val_step(self, data_batch, **kwargs):
"""Validation step.

Args:
data_batch (dict): A batch of data.
kwargs (dict): Other arguments for ``val_step``.

Returns:
dict: Returned output.
"""
output = self.forward_test(**data_batch, **kwargs)
return output
95 changes: 95 additions & 0 deletions tests/test_dic_model.py
@@ -0,0 +1,95 @@
import numpy as np
import pytest
import torch
from mmcv.runner import obj_from_dict
from mmcv.utils.config import Config

from mmedit.models.builder import build_model


def test_dic_model():

model_cfg = dict(
type='DIC',
generator=dict(
type='DICNet',
in_channels=3,
out_channels=3,
mid_channels=48,
num_blocks=6,
hg_mid_channels=256,
hg_num_keypoints=68,
num_steps=4,
upscale_factor=8,
detach_attention=False),
pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
align_loss=dict(type='MSELoss', loss_weight=0.1, reduction='mean'))

scale = 8
train_cfg = None
test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))

# build restorer
restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)

# test attributes
assert restorer.__class__.__name__ == 'DIC'

# prepare data
inputs = torch.rand(1, 3, 16, 16)
targets = torch.rand(1, 3, 128, 128)
heatmap = torch.rand(1, 68, 32, 32)
data_batch = {'lq': inputs, 'gt': targets, 'heatmap': heatmap}

# prepare optimizer
optim_cfg = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
optimizer = dict(
generator=obj_from_dict(optim_cfg, torch.optim,
dict(params=restorer.parameters())))

# test train_step and forward_test (cpu)
outputs = restorer.train_step(data_batch, optimizer)
assert isinstance(outputs, dict)
assert isinstance(outputs['log_vars'], dict)
assert isinstance(outputs['log_vars']['loss_pixel_v3'], float)
assert outputs['num_samples'] == 1
assert outputs['results']['lq'].shape == data_batch['lq'].shape
assert outputs['results']['gt'].shape == data_batch['gt'].shape
assert torch.is_tensor(outputs['results']['output'])
assert outputs['results']['output'].size() == (1, 3, 128, 128)

# test train_step and forward_test (gpu)
if torch.cuda.is_available():
restorer = restorer.cuda()
data_batch = {
'lq': inputs.cuda(),
'gt': targets.cuda(),
'heatmap': heatmap.cuda()
}

# train_step
optimizer = dict(
generator=obj_from_dict(optim_cfg, torch.optim,
dict(params=restorer.parameters())))
outputs = restorer.train_step(data_batch, optimizer)
assert isinstance(outputs, dict)
assert isinstance(outputs['log_vars'], dict)
assert isinstance(outputs['log_vars']['loss_pixel_v3'], float)
assert outputs['num_samples'] == 1
assert outputs['results']['lq'].shape == data_batch['lq'].shape
assert outputs['results']['gt'].shape == data_batch['gt'].shape
assert torch.is_tensor(outputs['results']['output'])
assert outputs['results']['output'].size() == (1, 3, 128, 128)

# val_step
data_batch.pop('heatmap')
result = restorer.val_step(data_batch, meta=[{'gt_path': ''}])
assert isinstance(result, dict)
assert isinstance(result['eval_result'], dict)
assert result['eval_result'].keys() == {'PSNR', 'SSIM'}
assert isinstance(result['eval_result']['PSNR'], np.float64)
assert isinstance(result['eval_result']['SSIM'], np.float64)

with pytest.raises(AssertionError):
# evaluation with metrics must have gt images
restorer(lq=inputs.cuda(), test_mode=True)