Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
liyinshuo committed May 25, 2021
1 parent f246dde commit 0211082
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 18 deletions.
32 changes: 17 additions & 15 deletions mmedit/models/restorers/ttsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def __init__(self,
if pretrained:
self.init_weights(pretrained)

def forward_dummy(self, lq, lq_pad, ref, ref_pad, only_pred=True):
def forward_dummy(self, lq, lq_up, ref, ref_downup, only_pred=True):
"""Forward of networks.
Args:
lq (Tensor): LQ image
lq_pad (Tensor): Upsampled LQ image
lq_up (Tensor): Upsampled LQ image
ref (Tensor): Reference image
ref_pad (Tensor): Image generated by sequentially applying
ref_downup (Tensor): Image generated by sequentially applying
bicubic down-sampling and up-sampling on reference image
only_pred (bool): Only return pred or not. Default: True
Expand All @@ -73,13 +73,15 @@ def forward_dummy(self, lq, lq_pad, ref, ref_pad, only_pred=True):
(n, c, 4h, 4w)
"""

_, _, lq_pad_level3 = self.extractor((lq_pad.detach() + 1.) / 2.)
_, _, ref_pad_level3 = self.extractor((ref_pad.detach() + 1.) / 2.)
_, _, lq_up_level3 = self.extractor((lq_up.detach() + 1.) / 2.)
_, _, ref_downup_level3 = self.extractor(
(ref_downup.detach() + 1.) / 2.)
ref_level1, ref_level2, ref_level3 = self.extractor(
(ref.detach() + 1.) / 2.)

s, t_level3, t_level2, t_level1 = self.transformer(
lq_pad_level3, ref_pad_level3, ref_level1, ref_level2, ref_level3)
lq_up_level3, ref_downup_level3, ref_level1, ref_level2,
ref_level3)

pred = self.generator(lq, s, t_level3, t_level2, t_level1)

Expand Down Expand Up @@ -107,7 +109,7 @@ def train_step(self, data_batch, optimizer):
Args:
data_batch (dict): A batch of data, which requires
'lq', 'gt', 'lq_pad', 'ref', 'ref_pad'
'lq', 'gt', 'lq_up', 'ref', 'ref_downup'
optimizer (obj): Optimizer.
Returns:
Expand All @@ -117,13 +119,13 @@ def train_step(self, data_batch, optimizer):
"""
# data
lq = data_batch['lq']
lq_pad = data_batch['lq_pad']
lq_up = data_batch['lq_up']
gt = data_batch['gt']
ref = data_batch['ref']
ref_pad = data_batch['ref_pad']
ref_downup = data_batch['ref_downup']

# generate
pred = self.forward_dummy(lq, lq_pad, ref, ref_pad)
pred = self.forward_dummy(lq, lq_up, ref, ref_downup)

# loss
losses = dict()
Expand Down Expand Up @@ -153,9 +155,9 @@ def train_step(self, data_batch, optimizer):

def forward_test(self,
lq,
lq_pad,
lq_up,
ref,
ref_pad,
ref_downup,
gt=None,
meta=None,
save_image=False,
Expand All @@ -166,9 +168,9 @@ def forward_test(self,
Args:
lq (Tensor): LQ image
gt (Tensor): GT image
lq_pad (Tensor): Upsampled LQ image
lq_up (Tensor): Upsampled LQ image
ref (Tensor): Reference image
ref_pad (Tensor): Image generated by sequentially applying
ref_downup (Tensor): Image generated by sequentially applying
bicubic down-sampling and up-sampling on reference image
meta (list[dict]): Meta data, such as path of GT file.
Default: None.
Expand All @@ -187,7 +189,7 @@ def forward_test(self,
# generator
with torch.no_grad():
pred = self.forward_dummy(
lq=lq, lq_pad=lq_pad, ref=ref, ref_pad=ref_pad)
lq=lq, lq_up=lq_up, ref=ref, ref_downup=ref_downup)

pred = (pred + 1.) / 2.
if gt is not None:
Expand Down
89 changes: 86 additions & 3 deletions tests/test_ttsr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
import torch
from mmcv.runner import obj_from_dict
from mmcv.utils.config import Config

from mmedit.models import build_backbone
from mmedit.models import build_backbone, build_model
from mmedit.models.backbones.sr_backbones.ttsr_net import (CSFI2, CSFI3, SFE,
MergeFeatures)

Expand Down Expand Up @@ -59,5 +62,85 @@ def test_ttsr_net():


def test_ttsr():
    """End-to-end test of the TTSR restorer: build, train_step, val_step.

    Covers model construction from config, attribute checks, a CPU
    training step, and (when CUDA is available) a GPU training step plus
    ``val_step`` metric evaluation.
    """
    model_cfg = dict(
        type='TTSR',
        generator=dict(
            type='TTSRNet',
            in_channels=3,
            out_channels=3,
            mid_channels=64,
            num_blocks=(16, 16, 8, 4)),
        extractor=dict(type='LTE'),
        transformer=dict(type='SearchTransformer'),
        pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))

    scale = 4
    train_cfg = None
    test_cfg = Config(dict(metrics=['PSNR', 'SSIM'], crop_border=scale))

    # build restorer
    restorer = build_model(model_cfg, train_cfg=train_cfg, test_cfg=test_cfg)

    # test attributes
    assert restorer.__class__.__name__ == 'TTSR'

    # prepare data: lq is 16x16 with scale 4, so all auxiliary inputs are
    # 64x64.  Use distinct tensors for lq_up / ref_downup instead of
    # reusing `ref`, so each key matches its documented role (upsampled LQ
    # and down-up-sampled reference, respectively).
    inputs = torch.rand(1, 3, 16, 16)
    targets = torch.rand(1, 3, 64, 64)
    ref = torch.rand(1, 3, 64, 64)
    lq_up = torch.rand(1, 3, 64, 64)
    ref_downup = torch.rand(1, 3, 64, 64)
    data_batch = {
        'lq': inputs,
        'gt': targets,
        'ref': ref,
        'lq_up': lq_up,
        'ref_downup': ref_downup
    }

    # prepare optimizer
    optim_cfg = dict(type='Adam', lr=1e-4, betas=(0.9, 0.999))
    optimizer = obj_from_dict(optim_cfg, torch.optim,
                              dict(params=restorer.parameters()))

    # test train_step and forward_test (cpu)
    outputs = restorer.train_step(data_batch, optimizer)
    assert isinstance(outputs, dict)
    assert isinstance(outputs['log_vars'], dict)
    assert isinstance(outputs['log_vars']['loss_pix'], float)
    assert outputs['num_samples'] == 1
    assert outputs['results']['lq'].shape == data_batch['lq'].shape
    assert outputs['results']['gt'].shape == data_batch['gt'].shape
    assert torch.is_tensor(outputs['results']['output'])
    assert outputs['results']['output'].size() == (1, 3, 64, 64)

    # test train_step and forward_test (gpu)
    if torch.cuda.is_available():
        restorer = restorer.cuda()
        data_batch = {
            'lq': inputs.cuda(),
            'gt': targets.cuda(),
            'ref': ref.cuda(),
            'lq_up': lq_up.cuda(),
            'ref_downup': ref_downup.cuda()
        }

        # train_step
        optimizer = obj_from_dict(optim_cfg, torch.optim,
                                  dict(params=restorer.parameters()))
        outputs = restorer.train_step(data_batch, optimizer)
        assert isinstance(outputs, dict)
        assert isinstance(outputs['log_vars'], dict)
        assert isinstance(outputs['log_vars']['loss_pix'], float)
        assert outputs['num_samples'] == 1
        assert outputs['results']['lq'].shape == data_batch['lq'].shape
        assert outputs['results']['gt'].shape == data_batch['gt'].shape
        assert torch.is_tensor(outputs['results']['output'])
        assert outputs['results']['output'].size() == (1, 3, 64, 64)

        # val_step: evaluates PSNR/SSIM against gt via the test_cfg metrics
        result = restorer.val_step(data_batch, meta=[{'gt_path': ''}])
        assert isinstance(result, dict)
        assert isinstance(result['eval_result'], dict)
        # plain set literal; set({...}) was a redundant set-of-a-set
        assert result['eval_result'].keys() == {'PSNR', 'SSIM'}
        assert isinstance(result['eval_result']['PSNR'], np.float64)
        assert isinstance(result['eval_result']['SSIM'], np.float64)

0 comments on commit 0211082

Please sign in to comment.