[Feature] Add TTSR. (#321)
* [Feature] Add TTSR.

* Add test

* Fix

Co-authored-by: liyinshuo <[email protected]>
Yshuo-Li and liyinshuo authored May 25, 2021
1 parent f3c9dd6 commit b066bb6
Showing 3 changed files with 331 additions and 2 deletions.
5 changes: 4 additions & 1 deletion mmedit/models/restorers/__init__.py
@@ -4,5 +4,8 @@
from .esrgan import ESRGAN
from .liif import LIIF
from .srgan import SRGAN
+from .ttsr import TTSR

-__all__ = ['BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR']
+__all__ = [
+    'BasicRestorer', 'SRGAN', 'ESRGAN', 'EDVR', 'LIIF', 'BasicVSR', 'TTSR'
+]
238 changes: 238 additions & 0 deletions mmedit/models/restorers/ttsr.py
@@ -0,0 +1,238 @@
import numbers
import os.path as osp

import mmcv
import torch

from mmedit.core import tensor2img
from ..builder import build_backbone, build_component, build_loss
from ..registry import MODELS
from .basic_restorer import BasicRestorer


@MODELS.register_module()
class TTSR(BasicRestorer):
"""TTSR model for Reference-based Image Super-Resolution.
Paper: Learning Texture Transformer Network for Image Super-Resolution.
Args:
generator (dict): Config for the generator.
extractor (dict): Config for the extractor.
transformer (dict): Config for the transformer.
pixel_loss (dict): Config for the pixel loss.
train_cfg (dict): Config for train. Default: None.
test_cfg (dict): Config for testing. Default: None.
pretrained (str): Path for pretrained model. Default: None.
"""

    def __init__(self,
                 generator,
                 extractor,
                 transformer,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(BasicRestorer, self).__init__()

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # model
        self.generator = build_backbone(generator)
        self.transformer = build_component(transformer)
        self.extractor = build_component(extractor)

        # loss
        self.pixel_loss = build_loss(pixel_loss)

        # pretrained
        self.init_weights(pretrained)

    def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
        """Forward of networks.

        Args:
            lq (Tensor): LQ image.
            lq_up (Tensor): Upsampled LQ image.
            ref (Tensor): Reference image.
            ref_downup (Tensor): Image generated by sequentially applying
                bicubic down-sampling and up-sampling on the reference image.
            only_pred (bool): Whether to return only the predicted result.
                Default: True.

        Returns:
            pred (Tensor): Predicted super-resolution results (n, 3, 4h, 4w).
            s (Tensor): Soft-Attention tensor with shape (n, 1, h, w).
            t_level3 (Tensor): Transformed HR texture T in level 3,
                with shape (n, 4c, h, w).
            t_level2 (Tensor): Transformed HR texture T in level 2,
                with shape (n, 2c, 2h, 2w).
            t_level1 (Tensor): Transformed HR texture T in level 1,
                with shape (n, c, 4h, 4w).
        """

        _, _, lq_up_level3 = self.extractor(lq_up)
        _, _, ref_downup_level3 = self.extractor(ref_downup)
        ref_level1, ref_level2, ref_level3 = self.extractor(ref)

        s, t_level3, t_level2, t_level1 = self.transformer(
            lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
            ref_level3)

        pred = self.generator(lq, s, t_level3, t_level2, t_level1)

        if only_pred:
            return pred
        return pred, s, t_level3, t_level2, t_level1

    def forward(self, lq, gt=None, test_mode=False, **kwargs):
        """Forward function.

        Args:
            lq (Tensor): Input lq images.
            gt (Tensor): Ground-truth image. Default: None.
            test_mode (bool): Whether in test mode or not. Default: False.
            kwargs (dict): Other arguments.
        """

        if test_mode:
            return self.forward_test(lq, gt=gt, **kwargs)

        return self.forward_dummy(lq, **kwargs)

    def train_step(self, data_batch, optimizer):
        """Train step.

        Args:
            data_batch (dict): A batch of data, which requires the keys
                'lq', 'gt', 'lq_up', 'ref' and 'ref_downup'.
            optimizer (obj): Optimizer.

        Returns:
            dict: Returned output, which includes:
                log_vars, num_samples, results (lq, gt, ref and output).
        """
        # data
        lq = data_batch['lq']
        lq_up = data_batch['lq_up']
        gt = data_batch['gt']
        ref = data_batch['ref']
        ref_downup = data_batch['ref_downup']

        # generate
        pred = self.forward_dummy(lq, lq_up, ref, ref_downup)

        # loss
        losses = dict()

        losses['loss_pix'] = self.pixel_loss(pred, gt)

        # parse loss
        loss, log_vars = self.parse_losses(losses)

        # optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(
                lq=lq.cpu(), gt=gt.cpu(), ref=ref.cpu(), output=pred.cpu()))

        return outputs

    def forward_test(self,
                     lq,
                     lq_up,
                     ref,
                     ref_downup,
                     gt=None,
                     meta=None,
                     save_image=False,
                     save_path=None,
                     iteration=None):
        """Testing forward function.

        Args:
            lq (Tensor): LQ image.
            lq_up (Tensor): Upsampled LQ image.
            ref (Tensor): Reference image.
            ref_downup (Tensor): Image generated by sequentially applying
                bicubic down-sampling and up-sampling on the reference image.
            gt (Tensor): GT image. Default: None.
            meta (list[dict]): Meta data, such as path of GT file.
                Default: None.
            save_image (bool): Whether to save image. Default: False.
            save_path (str): Path to save image. Default: None.
            iteration (int): Iteration for the saving image name.
                Default: None.

        Returns:
            dict: Output results, which contain one of the following key sets:
                1. 'eval_result'.
                2. 'lq', 'output'.
                3. 'lq', 'output', 'gt'.
        """

        # generator
        with torch.no_grad():
            pred = self.forward_dummy(
                lq=lq, lq_up=lq_up, ref=ref, ref_downup=ref_downup)

        pred = (pred + 1.) / 2.
        if gt is not None:
            gt = (gt + 1.) / 2.

        if self.test_cfg is not None and self.test_cfg.get('metrics', None):
            assert gt is not None, (
                'evaluation with metrics must have gt images.')
            results = dict(eval_result=self.evaluate(pred, gt))
        else:
            results = dict(lq=lq.cpu(), output=pred.cpu())
            if gt is not None:
                results['gt'] = gt.cpu()

        # save image
        if save_image:
            if 'gt_path' in meta[0]:
                the_path = meta[0]['gt_path']
            else:
                the_path = meta[0]['lq_path']
            folder_name = osp.splitext(osp.basename(the_path))[0]
            if isinstance(iteration, numbers.Number):
                save_path = osp.join(save_path, folder_name,
                                     f'{folder_name}-{iteration + 1:06d}.png')
            elif iteration is None:
                save_path = osp.join(save_path, f'{folder_name}.png')
            else:
                raise ValueError('iteration should be a number or None, '
                                 f'but got {type(iteration)}')
            mmcv.imwrite(tensor2img(pred), save_path)

        return results

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            if self.generator:
                self.generator.init_weights(pretrained, strict)
            if self.extractor:
                self.extractor.init_weights(pretrained, strict)
            if self.transformer:
                self.transformer.init_weights(pretrained, strict)
        elif pretrained is not None:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
90 changes: 89 additions & 1 deletion tests/test_ttsr.py
@@ -1,6 +1,9 @@
+import numpy as np
import torch
+from mmcv.runner import obj_from_dict
+from mmcv.utils.config import Config

-from mmedit.models import build_backbone
+from mmedit.models import build_backbone, build_model
from mmedit.models.backbones.sr_backbones.ttsr_net import (CSFI2, CSFI3, SFE,
                                                           MergeFeatures)
@@ -56,3 +59,88 @@ def test_ttsr_net():
    outputs = ttsr(inputs, s, t_level3, t_level2, t_level1)

    assert outputs.shape == (2, 3, 96, 96)


def test_ttsr():

    model_cfg = dict(
        type='TTSR',
        generator=dict(
            type='TTSRNet',
            in_channels=3,
            out_channels=3,
            mid_channels=64,
            num_blocks=(16, 16, 8, 4)),
        extractor=dict(type='LTE'),
        transformer=dict(type='SearchTransformer'),
        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))

    scale = 4
    train_cfg = None
    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))

    # build restorer
    restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)

    # test attributes
    assert restorer.__class__.__name__ == 'TTSR'

    # prepare data
    inputs = torch.rand(1, 3, 16, 16)
    targets = torch.rand(1, 3, 64, 64)
    ref = torch.rand(1, 3, 64, 64)
    data_batch = {
        'lq': inputs,
        'gt': targets,
        'ref': ref,
        'lq_up': ref,
        'ref_downup': ref
    }

    # prepare optimizer
    optim_cfg = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
    optimizer = obj_from_dict(optim_cfg, torch.optim,
                              dict(params=restorer.parameters()))

    # test train_step and forward_test (cpu)
    outputs = restorer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['log_vars']['loss_pix'], float)
    assert outputs['num_samples'] == 1
    assert outputs['results']['lq'].shape == data_batch['lq'].shape
    assert outputs['results']['gt'].shape == data_batch['gt'].shape
    assert torch.is_tensor(outputs['results']['output'])
    assert outputs['results']['output'].size() == (1, 3, 64, 64)

    # test train_step and forward_test (gpu)
    if torch.cuda.is_available():
        restorer = restorer.cuda()
        data_batch = {
            'lq': inputs.cuda(),
            'gt': targets.cuda(),
            'ref': ref.cuda(),
            'lq_up': ref.cuda(),
            'ref_downup': ref.cuda()
        }

        # train_step
        optimizer = obj_from_dict(optim_cfg, torch.optim,
                                  dict(params=restorer.parameters()))
        outputs = restorer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['log_vars']['loss_pix'], float)
        assert outputs['num_samples'] == 1
        assert outputs['results']['lq'].shape == data_batch['lq'].shape
        assert outputs['results']['gt'].shape == data_batch['gt'].shape
        assert torch.is_tensor(outputs['results']['output'])
        assert outputs['results']['output'].size() == (1, 3, 64, 64)

        # val_step
        result = restorer.val_step(data_batch, meta=[{'gt_path': ''}])
        assert isinstance(result, dict)
        assert isinstance(result['eval_result'], dict)
        assert result['eval_result'].keys() == set({'PSNR', 'SSIM'})
        assert isinstance(result['eval_result']['PSNR'], np.float64)
        assert isinstance(result['eval_result']['SSIM'], np.float64)

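The val_step call exercised above is not defined in this commit; it is inherited from BasicRestorer and hands the batch plus extra keyword arguments (here meta) to the test-time forward path, which is why eval_result comes back when test_cfg.metrics is set. Below is a small sketch of the batch layout that train_step and forward_test expect; the helper name is hypothetical, the key names come from the implementation above, and the sizes assume the 4x setting used in the test.

import torch

# Hypothetical helper illustrating the expected batch keys and shapes.
def make_dummy_batch(batch_size=1, lq_size=16, scale=4):
    hr = lq_size * scale
    return dict(
        lq=torch.rand(batch_size, 3, lq_size, lq_size),
        gt=torch.rand(batch_size, 3, hr, hr),
        ref=torch.rand(batch_size, 3, hr, hr),
        lq_up=torch.rand(batch_size, 3, hr, hr),
        ref_downup=torch.rand(batch_size, 3, hr, hr))

batch = make_dummy_batch()
assert {k: tuple(v.shape) for k, v in batch.items()} == {
    'lq': (1, 3, 16, 16),
    'gt': (1, 3, 64, 64),
    'ref': (1, 3, 64, 64),
    'lq_up': (1, 3, 64, 64),
    'ref_downup': (1, 3, 64, 64)
}

In a real pipeline lq_up and ref_downup would be produced by bicubic resampling rather than random tensors; the test reuses ref for them purely as a shape placeholder.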