[Feature] Add DIC #363
Merged
Changes from 2 of 5 commits
New file (the DIC model):
@@ -0,0 +1,197 @@
import numbers
import os.path as osp
from collections import OrderedDict

import mmcv
import torch

from mmedit.core import tensor2img
from mmedit.models.common import ImgNormalize
from ..builder import build_backbone, build_loss
from ..registry import MODELS
from .basic_restorer import BasicRestorer


@MODELS.register_module()
class DIC(BasicRestorer):
    """DIC model for Face Super-Resolution.

    Paper: Deep Face Super-Resolution with Iterative Collaboration between
    Attentive Recovery and Landmark Estimation.

    Args:
        generator (dict): Config for the generator.
        pixel_loss (dict): Config for the pixel loss.
        align_loss (dict): Config for the align loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        pretrained (str): Path for pretrained model. Default: None.
    """

    def __init__(self,
                 generator,
                 pixel_loss,
                 align_loss,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # Skip BasicRestorer.__init__ on purpose (it builds its own
        # generator and loss); initialize from its parent class instead.
        super(BasicRestorer, self).__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # model
        self.generator = build_backbone(generator)
        self.img_denormalize = ImgNormalize(
            pixel_range=1,
            img_mean=(0.509, 0.424, 0.378),
            img_std=(1., 1., 1.),
            sign=1)

        # loss
        self.pixel_loss = build_loss(pixel_loss)
        self.align_loss = build_loss(align_loss)

        # pretrained
        if pretrained:
            self.init_weights(pretrained)

    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments.
        """
        if test_mode:
            return self.forward_test(lq, gt=gt, **kwargs)

        return self.generator.forward(lq)

    def train_step(self, data_batch, optimizer):
        """Train step.

        Args:
            data_batch (dict): A batch of data, which requires
                'lq', 'gt' and 'heatmap'.
            optimizer (obj): Optimizer.

        Returns:
            dict: Returned output, which includes:
                log_vars, num_samples, results (lq, gt and pred).
        """
        # data
        lq = data_batch['lq']
        gt = data_batch['gt']
        gt_heatmap = data_batch['heatmap']

        # generate
        sr_list, heatmap_list = self.generator.forward(lq)

        # loss: one pixel loss and one alignment (heatmap) loss per
        # iterative step of the generator
        losses = OrderedDict()
        for step, (sr, heatmap) in enumerate(zip(sr_list, heatmap_list)):
            losses[f'loss_pixel_v{step}'] = self.pixel_loss(sr, gt)
            losses[f'loss_align_v{step}'] = self.align_loss(
                heatmap, gt_heatmap)

        # parse loss
        loss, log_vars = self.parse_losses(losses)

        # optimize
        optimizer['generator'].zero_grad()
        loss.backward()
        optimizer['generator'].step()

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=sr_list[-1].cpu()))

        return outputs

    def forward_test(self,
                     lq,
                     gt=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Args:
            lq (Tensor): LQ image.
            gt (Tensor): GT image. Default: None.
            meta (list[dict]): Meta data, such as path of GT file.
                Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: Output results, which contain one of the following
                key sets:
                1. 'eval_result'.
                2. 'lq', 'output'.
                3. 'lq', 'output', 'gt'.
        """
        # generator
        with torch.no_grad():
            sr_list, _ = self.generator.forward(lq)
            pred = sr_list[-1]
            pred = self.img_denormalize(pred)

            if gt is not None:
                gt = self.img_denormalize(gt)

        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
            assert gt is not None, (
                'evaluation with metrics must have gt images.')
            results = dict(eval_result=self.evaluate(pred, gt))
        else:
            results = dict(lq=lq.cpu(), output=pred.cpu())
            if gt is not None:
                results['gt'] = gt.cpu()

        # save image
        if save_image:
            if 'gt_path' in meta[0]:
                the_path = meta[0]['gt_path']
            else:
                the_path = meta[0]['lq_path']

            folder_name = osp.splitext(osp.basename(the_path))[0]
            if isinstance(iteration, numbers.Number):
                save_path = osp.join(save_path, folder_name,
                                     f'{folder_name}-{iteration + 1:06d}.png')
            elif iteration is None:
                save_path = osp.join(save_path, f'{folder_name}.png')
            else:
                raise ValueError('iteration should be number or None, '
                                 f'but got {type(iteration)}')
            mmcv.imwrite(tensor2img(pred), save_path)

        return results

    def val_step(self, data_batch, **kwargs):
        """Validation step.

        Args:
            data_batch (dict): A batch of data.
            kwargs (dict): Other arguments for ``val_step``.

        Returns:
            dict: Returned output.
        """
        output = self.forward_test(**data_batch, **kwargs)
        return output
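
For context, a minimal sketch of how the image-saving branch of forward_test above would be driven. Everything here is illustrative, not taken from this PR: the variable restorer is assumed to be a built DIC model whose test_cfg carries no 'metrics' entry (otherwise gt would be required), and the paths are invented.

import torch

# Illustrative inputs: a normalized low-resolution batch, plus meta
# carrying the GT path used to derive the output file name.
lq = torch.rand(1, 3, 16, 16)
meta = [dict(gt_path='data/celeba/0001.png')]  # hypothetical path
results = restorer.forward_test(
    lq,
    meta=meta,
    save_image=True,
    save_path='work_dirs/vis',  # hypothetical directory
    iteration=9999)
# With a numeric iteration, the name pattern is
# '{folder_name}-{iteration + 1:06d}.png', so this call writes
# work_dirs/vis/0001/0001-010000.png.
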
New file (unit test for the DIC model):
@@ -0,0 +1,95 @@
import numpy as np
import pytest
import torch
from mmcv.runner import obj_from_dict
from mmcv.utils.config import Config

from mmedit.models.builder import build_model


def test_dic_model():
    model_cfg = dict(
        type='DIC',
        generator=dict(
            type='DICNet',
            in_channels=3,
            out_channels=3,
            mid_channels=48,
            num_blocks=6,
            hg_mid_channels=256,
            hg_num_keypoints=68,
            num_steps=4,
            upscale_factor=8,
            detach_attention=False),
        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'),
        align_loss=dict(type='MSELoss', loss_weight=0.1, reduction='mean'))

    scale = 8
    train_cfg = None
    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))

    # build restorer
    restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)

    # test attributes
    assert restorer.__class__.__name__ == 'DIC'

    # prepare data
    inputs = torch.rand(1, 3, 16, 16)
    targets = torch.rand(1, 3, 128, 128)
    heatmap = torch.rand(1, 68, 32, 32)
    data_batch = {'lq': inputs, 'gt': targets, 'heatmap': heatmap}

    # prepare optimizer
    optim_cfg = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
    optimizer = dict(
        generator=obj_from_dict(optim_cfg, torch.optim,
                                dict(params=restorer.parameters())))

    # test train_step and forward_test (cpu)
    outputs = restorer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['log_vars']['loss_pixel_v3'], float)
    assert outputs['num_samples'] == 1
    assert outputs['results']['lq'].shape == data_batch['lq'].shape
    assert outputs['results']['gt'].shape == data_batch['gt'].shape
    assert torch.is_tensor(outputs['results']['output'])
    assert outputs['results']['output'].size() == (1, 3, 128, 128)

    # test train_step and forward_test (gpu)
    if torch.cuda.is_available():
        restorer = restorer.cuda()
        data_batch = {
            'lq': inputs.cuda(),
            'gt': targets.cuda(),
            'heatmap': heatmap.cuda()
        }

        # train_step
        optimizer = dict(
            generator=obj_from_dict(optim_cfg, torch.optim,
                                    dict(params=restorer.parameters())))
        outputs = restorer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['log_vars']['loss_pixel_v3'], float)
        assert outputs['num_samples'] == 1
        assert outputs['results']['lq'].shape == data_batch['lq'].shape
        assert outputs['results']['gt'].shape == data_batch['gt'].shape
        assert torch.is_tensor(outputs['results']['output'])
        assert outputs['results']['output'].size() == (1, 3, 128, 128)

        # val_step
        data_batch.pop('heatmap')
        result = restorer.val_step(data_batch, meta=[{'gt_path': ''}])
        assert isinstance(result, dict)
        assert isinstance(result['eval_result'], dict)
        assert result['eval_result'].keys() == {'PSNR', 'SSIM'}
        assert isinstance(result['eval_result']['PSNR'], np.float64)
        assert isinstance(result['eval_result']['SSIM'], np.float64)

        with pytest.raises(AssertionError):
            # evaluation with metrics must have gt images
            restorer(lq=inputs.cuda(), test_mode=True)
Review conversation:

Comment: Where is it used?
Comment: Out of curiosity, where is the normalize that corresponds to this denormalize?
Reply: The normalize is in the pipeline (it is applied via the config file, similar to the other methods in this repo).
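
For reference, a minimal sketch of what such a pipeline normalization step usually looks like in this repo's config files. The keys and statistics below simply mirror the values hard-coded in img_denormalize and are assumptions, not the actual DIC config:

train_pipeline = [
    # ... loading/augmentation steps ...
    # Normalize with the same statistics that DIC.img_denormalize later
    # inverts (values mirrored from the model code; the real DIC config
    # may differ).
    dict(
        type='Normalize',
        keys=['lq', 'gt'],
        mean=[0.509, 0.424, 0.378],
        std=[1., 1., 1.],
        to_rgb=False),
    # ... formatting/collection steps ...
]
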
Comment: How about the first question?
Reply: Q: Where is it used?
A: It was deleted by mistake in the previous version (f0b85d2). This problem has been fixed.
Comment: Can this de-normalization be configurable? Just asking... We can leave the better normalize-denormalize story to future work. Currently they are quite ad-hoc.
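
If it were made configurable, one possible shape is to lift the hard-coded statistics into constructor arguments with the current values as defaults. This is a sketch only; the parameter names are hypothetical and not part of this PR:

# Hypothetical variant of DIC.__init__ (names invented for illustration;
# the remaining parameters and body are unchanged from the PR).
def __init__(self,
             generator,
             pixel_loss,
             align_loss,
             denorm_mean=(0.509, 0.424, 0.378),
             denorm_std=(1., 1., 1.),
             train_cfg=None,
             test_cfg=None,
             pretrained=None):
    super(BasicRestorer, self).__init__()
    self.img_denormalize = ImgNormalize(
        pixel_range=1, img_mean=denorm_mean, img_std=denorm_std, sign=1)
    # ... rest of __init__ as above ...
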
Reply: post-pipeline?
Reply: Maybe...