[feature] Liif rdn (#440)
* Add checkpoints

* [Feature] Add LIIF-RDN

* Add checkpoints

* Update

* Fix

Co-authored-by: liyinshuo <[email protected]>
Yshuo-Li and liyinshuo authored Jul 22, 2021
1 parent 3167261 commit 43eddb9
Showing 7 changed files with 303 additions and 14 deletions.
6 changes: 6 additions & 0 deletions configs/restorers/liif/README.md
@@ -27,6 +27,12 @@
|| x6 | 27.1187 / 0.7774 | 24.7461 / 0.6444 | 26.7770 / 0.7425 ||
|| x18 | 20.8516 / 0.5406 | 20.0096 / 0.4525 | 22.1987 / 0.5955 ||
|| x30 | 18.8467 / 0.5010 | 18.1321 / 0.3963 | 20.5050 / 0.5577 ||
| [liif_rdn_norm_c64b16_g1_1000k_div2k](/configs/restorers/liif/liif_rdn_norm_x2-4_c64b16_g1_1000k_div2k.py) | x2 | 35.7874 / 0.9366 | 31.6866 / 0.8896 | 34.7548 / 0.9356 | [model](https://download.openmmlab.com/mmediting/restorers/liif/liif_rdn_norm_c64b16_g1_1000k_div2k_20210717-22d6fdc8.pth) \| [log](https://download.openmmlab.com/mmediting/restorers/liif/liif_rdn_norm_c64b16_g1_1000k_div2k_20210717-22d6fdc8.log.json) |
|| x3 | 32.4992 / 0.8923 | 28.4905 / 0.8037 | 31.0744 / 0.8731 ||
|| x4 | 30.3835 / 0.8513 | 26.8734 / 0.7373 | 29.1101 / 0.8197 ||
|| x6 | 27.1914 / 0.7751 | 24.7824 / 0.6434 | 26.8693 / 0.7437 ||
|| x18 | 20.8913 / 0.5329 | 20.1077 / 0.4537 | 22.2972 / 0.5950 ||
|| x30 | 18.9354 / 0.4864 | 18.1448 / 0.3942 | 20.5663 / 0.5560 ||

Note:
* △ refers to ditto.
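The LIIF-RDN weights linked in the new table row can be pulled straight from the URL above for a quick sanity check. Below is a minimal sketch, not part of this commit, that builds the restorer from the new config and loads the released checkpoint; it assumes the standard mmcv/mmedit builder APIs and that the config path is resolved from the repository root.

# Sketch only (assumed paths and standard mmcv/mmedit APIs); not part of this commit.
import mmcv
from mmcv.runner import load_checkpoint

from mmedit.models import build_model

cfg = mmcv.Config.fromfile(
    'configs/restorers/liif/liif_rdn_norm_x2-4_c64b16_g1_1000k_div2k.py')
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
load_checkpoint(
    model,
    'https://download.openmmlab.com/mmediting/restorers/liif/'
    'liif_rdn_norm_c64b16_g1_1000k_div2k_20210717-22d6fdc8.pth',
    map_location='cpu')
model.eval()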
20 changes: 12 additions & 8 deletions configs/restorers/liif/liif_edsr_norm_c64b16_g1_1000k_div2k.py
@@ -113,13 +113,18 @@
        gt_folder='data/val_set5/Set5',
        pipeline=valid_pipeline,
        scale=scale_max),
+    # test=dict(
+    #     type=test_dataset_type,
+    #     lq_folder=f'data/val_set5/Set5_bicLRx{scale_max:d}',
+    #     gt_folder='data/val_set5/Set5',
+    #     pipeline=test_pipeline,
+    #     scale=scale_max,
+    #     filename_tmpl='{}'),
    test=dict(
-        type=test_dataset_type,
-        lq_folder=f'data/val_set5/Set5_bicLRx{scale_max:d}',
+        type=val_dataset_type,
        gt_folder='data/val_set5/Set5',
-        pipeline=test_pipeline,
-        scale=scale_max,
-        filename_tmpl='{}'))
+        pipeline=valid_pipeline,
+        scale=scale_max))

# optimizer
optimizers = dict(type='Adam', lr=1.e-4)
@@ -133,9 +138,8 @@
    step=[200000, 400000, 600000, 800000],
    gamma=0.5)

-checkpoint_config = dict(
-    interval=iter_per_epoch, save_optimizer=True, by_epoch=False)
-evaluation = dict(interval=iter_per_epoch, save_image=True, gpu_collect=True)
+checkpoint_config = dict(interval=3000, save_optimizer=True, by_epoch=False)
+evaluation = dict(interval=3000, save_image=True, gpu_collect=True)
log_config = dict(
    interval=100,
    hooks=[
158 changes: 158 additions & 0 deletions configs/restorers/liif/liif_rdn_norm_x2-4_c64b16_g1_1000k_div2k.py
@@ -0,0 +1,158 @@
exp_name = 'liif_rdn_norm_x2-4_c64b16_g1_1000k_div2k'
scale_min, scale_max = 1, 4

# model settings
model = dict(
    type='LIIF',
    generator=dict(
        type='LIIFRDN',
        encoder=dict(
            type='RDN',
            in_channels=3,
            out_channels=3,
            mid_channels=64,
            num_blocks=16,
            upscale_factor=4,
            num_layers=8,
            channel_growth=64),
        imnet=dict(
            type='MLPRefiner',
            in_dim=64,
            out_dim=3,
            hidden_list=[256, 256, 256, 256]),
        local_ensemble=True,
        feat_unfold=True,
        cell_decode=True,
        eval_bsize=30000),
    rgb_mean=(0.5, 0.5, 0.5),
    rgb_std=(0.5, 0.5, 0.5),
    pixel_loss=dict(type='L1Loss', loss_weight=1.0, reduction='mean'))
# model training and testing settings
train_cfg = None
test_cfg = dict(metrics=['PSNR', 'SSIM'], crop_border=scale_max)

# dataset settings
scale_min, scale_max = 1, 4
# dataset settings
train_dataset_type = 'SRFolderGTDataset'
val_dataset_type = 'SRFolderGTDataset'
test_dataset_type = 'SRFolderDataset'
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        io_backend='disk',
        key='gt',
        flag='color',
        channel_order='rgb'),
    dict(
        type='RandomDownSampling',
        scale_min=scale_min,
        scale_max=scale_max,
        patch_size=48),
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
    dict(
        type='Flip', keys=['lq', 'gt'], flip_ratio=0.5,
        direction='horizontal'),
    dict(type='Flip', keys=['lq', 'gt'], flip_ratio=0.5, direction='vertical'),
    dict(type='RandomTransposeHW', keys=['lq', 'gt'], transpose_ratio=0.5),
    dict(type='ImageToTensor', keys=['lq', 'gt']),
    dict(type='GenerateCoordinateAndCell', sample_quantity=2304),
    dict(
        type='Collect',
        keys=['lq', 'gt', 'coord', 'cell'],
        meta_keys=['gt_path'])
]
valid_pipeline = [
    dict(
        type='LoadImageFromFile',
        io_backend='disk',
        key='gt',
        flag='color',
        channel_order='rgb'),
    dict(type='RandomDownSampling', scale_min=scale_max, scale_max=scale_max),
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
    dict(type='ImageToTensor', keys=['lq', 'gt']),
    dict(type='GenerateCoordinateAndCell'),
    dict(
        type='Collect',
        keys=['lq', 'gt', 'coord', 'cell'],
        meta_keys=['gt_path'])
]
test_pipeline = [
    dict(
        type='LoadImageFromFile',
        io_backend='disk',
        key='gt',
        flag='color',
        channel_order='rgb'),
    dict(
        type='LoadImageFromFile',
        io_backend='disk',
        key='lq',
        flag='color',
        channel_order='rgb'),
    dict(type='RescaleToZeroOne', keys=['lq', 'gt']),
    dict(type='ImageToTensor', keys=['lq', 'gt']),
    dict(type='GenerateCoordinateAndCell', scale=scale_max),
    dict(
        type='Collect',
        keys=['lq', 'gt', 'coord', 'cell'],
        meta_keys=['gt_path'])
]

data = dict(
    workers_per_gpu=8,
    train_dataloader=dict(samples_per_gpu=16, drop_last=True),
    val_dataloader=dict(samples_per_gpu=1),
    test_dataloader=dict(samples_per_gpu=1),
    train=dict(
        type='RepeatDataset',
        times=20,
        dataset=dict(
            type=train_dataset_type,
            gt_folder='data/DIV2K/DIV2K_train_HR',
            pipeline=train_pipeline,
            scale=scale_max)),
    val=dict(
        type=val_dataset_type,
        gt_folder='data/val_set5/Set5',
        pipeline=valid_pipeline,
        scale=scale_max),
    test=dict(
        type=test_dataset_type,
        lq_folder=f'data/val_set5/Set5_bicLRx{scale_max:d}',
        gt_folder='data/val_set5/Set5',
        pipeline=test_pipeline,
        scale=scale_max,
        filename_tmpl='{}'))

# optimizer
optimizers = dict(type='Adam', lr=1.e-4)

# learning policy
iter_per_epoch = 1000
total_iters = 1000 * iter_per_epoch
lr_config = dict(
    policy='Step',
    by_epoch=False,
    step=[200000, 400000, 600000, 800000],
    gamma=0.5)

checkpoint_config = dict(interval=3000, save_optimizer=True, by_epoch=False)
evaluation = dict(interval=3000, save_image=True, gpu_collect=True)
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(type='TensorboardLoggerHook')
    ])
visual_config = None

# runtime settings
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = f'./work_dirs/{exp_name}'
load_from = None
resume_from = None
workflow = [('train', 1)]
find_unused_parameters = True
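The GenerateCoordinateAndCell step in the pipelines above is what lets one model serve the x2–x30 scales reported in the README: the LIIF head is queried with normalised pixel-centre coordinates plus the cell size of the target grid. The helper below is a rough standalone sketch of that convention (my own names, not the mmedit transform itself), useful only to see the shapes the test code expects.

import torch


def make_coord_cell(height, width):
    """Pixel-centre coordinates in [-1, 1] and per-pixel cell sizes.

    Rough sketch of the convention behind ``GenerateCoordinateAndCell``;
    the helper name and details here are assumptions, not mmedit code.
    """
    ys = -1 + (2 * torch.arange(height).float() + 1) / height  # row centres
    xs = -1 + (2 * torch.arange(width).float() + 1) / width  # column centres
    coord = torch.stack(torch.meshgrid(ys, xs), dim=-1).view(-1, 2)
    cell = torch.ones_like(coord)
    cell[:, 0] *= 2 / height  # one output pixel spans 2/H in [-1, 1]
    cell[:, 1] *= 2 / width
    return coord, cell


coord, cell = make_coord_cell(128, 64)  # same point count as the unit test below
assert coord.shape == (128 * 64, 2) and cell.shape == (128 * 64, 2)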
8 changes: 4 additions & 4 deletions mmedit/models/backbones/__init__.py
@@ -10,9 +10,9 @@
                               ResNetEnc, ResShortcutDec, ResShortcutEnc,
                               SimpleEncoderDecoder)
from .generation_backbones import ResnetGenerator, UnetGenerator
-from .sr_backbones import (EDSR, LIIFEDSR, RDN, SRCNN, BasicVSRNet, DICNet,
-                           EDVRNet, GLEANStyleGANv2, IconVSR, MSRResNet,
-                           RRDBNet, TDANNet, TOFlow, TTSRNet)
+from .sr_backbones import (EDSR, LIIFEDSR, LIIFRDN, RDN, SRCNN, BasicVSRNet,
+                           DICNet, EDVRNet, GLEANStyleGANv2, IconVSR,
+                           MSRResNet, RRDBNet, TDANNet, TOFlow, TTSRNet)

__all__ = [
    'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
@@ -25,5 +25,5 @@
    'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
    'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
    'BasicVSRNet', 'IconVSR', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet',
-    'LIIFEDSR'
+    'LIIFEDSR', 'LIIFRDN'
]
4 changes: 2 additions & 2 deletions mmedit/models/backbones/sr_backbones/__init__.py
@@ -4,7 +4,7 @@
from .edvr_net import EDVRNet
from .glean_styleganv2 import GLEANStyleGANv2
from .iconvsr import IconVSR
-from .liif_net import LIIFEDSR
+from .liif_net import LIIFEDSR, LIIFRDN
from .rdn import RDN
from .rrdb_net import RRDBNet
from .sr_resnet import MSRResNet
@@ -16,5 +16,5 @@
__all__ = [
    'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', 'DICNet',
    'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet',
-    'LIIFEDSR'
+    'LIIFEDSR', 'LIIFRDN'
]
67 changes: 67 additions & 0 deletions mmedit/models/backbones/sr_backbones/liif_net.py
@@ -15,6 +15,11 @@ class LIIFNet(nn.Module):
    Paper: Learning Continuous Image Representation with
        Local Implicit Image Function

    The subclasses should define `generator` with `encoder` and `imnet`,
    and overwrite the function `gen_feature`.
    If `encoder` does not contain `mid_channels`, `__init__` should be
    overwritten.

    Args:
        encoder (dict): Config for the generator.
        imnet (dict): Config for the imnet.
@@ -253,3 +258,65 @@ def gen_feature(self, x):
        res += x

        return res


@BACKBONES.register_module()
class LIIFRDN(LIIFNet):
    """LIIF net based on RDN.

    Paper: Learning Continuous Image Representation with
        Local Implicit Image Function

    Args:
        encoder (dict): Config for the generator.
        imnet (dict): Config for the imnet.
        local_ensemble (bool): Whether to use local ensemble. Default: True.
        feat_unfold (bool): Whether to use feat unfold. Default: True.
        cell_decode (bool): Whether to use cell decode. Default: True.
        eval_bsize (int): Size of batched predict. Default: None.
    """

    def __init__(self,
                 encoder,
                 imnet,
                 local_ensemble=True,
                 feat_unfold=True,
                 cell_decode=True,
                 eval_bsize=None):
        super().__init__(
            encoder=encoder,
            imnet=imnet,
            local_ensemble=local_ensemble,
            feat_unfold=feat_unfold,
            cell_decode=cell_decode,
            eval_bsize=eval_bsize)

        self.sfe1 = self.encoder.sfe1
        self.sfe2 = self.encoder.sfe2
        self.rdbs = self.encoder.rdbs
        self.gff = self.encoder.gff
        self.num_blocks = self.encoder.num_blocks
        del self.encoder

    def gen_feature(self, x):
        """Generate feature.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """

        sfe1 = self.sfe1(x)
        sfe2 = self.sfe2(sfe1)

        x = sfe2
        local_features = []
        for i in range(self.num_blocks):
            x = self.rdbs[i](x)
            local_features.append(x)

        x = self.gff(torch.cat(local_features, 1)) + sfe1

        return x
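The LIIFNet docstring above spells out the subclass contract that LIIFRDN follows: build the encoder from its config, keep only the sub-modules gen_feature needs, and drop the rest. The snippet below is a hypothetical illustration of that contract; LIIFToy and its body attribute are invented names and do not exist in mmedit.

# Hypothetical sketch of the LIIFNet subclass contract; `LIIFToy` and its
# `body` attribute are invented for illustration and are not part of mmedit.
from mmedit.models.backbones.sr_backbones.liif_net import LIIFNet
from mmedit.models.registry import BACKBONES


@BACKBONES.register_module()
class LIIFToy(LIIFNet):
    """LIIF head on a toy encoder whose config contains `mid_channels`."""

    def __init__(self, encoder, imnet, **kwargs):
        super().__init__(encoder=encoder, imnet=imnet, **kwargs)
        # Mirror LIIFRDN: keep only what gen_feature needs, then free the rest.
        self.body = self.encoder.body  # assumed attribute of the toy encoder
        del self.encoder

    def gen_feature(self, x):
        """Return an (n, mid_channels, h, w) feature map for the LIIF head."""
        return self.body(x)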
@@ -52,3 +52,57 @@ def test_liif_edsr():
        output = model(inputs, coord, cell, True)
        assert torch.is_tensor(output)
        assert output.shape == targets.shape


def test_liif_rdn():

    model_cfg = dict(
        type='LIIFRDN',
        encoder=dict(
            type='RDN',
            in_channels=3,
            out_channels=3,
            mid_channels=64,
            num_blocks=16,
            upscale_factor=4,
            num_layers=8,
            channel_growth=64),
        imnet=dict(
            type='MLPRefiner',
            in_dim=64,
            out_dim=3,
            hidden_list=[256, 256, 256, 256]),
        local_ensemble=True,
        feat_unfold=True,
        cell_decode=True,
        eval_bsize=30000)

    # build model
    model = build_backbone(model_cfg)

    # test attributes
    assert model.__class__.__name__ == 'LIIFRDN'

    # prepare data
    inputs = torch.rand(1, 3, 22, 11)
    targets = torch.rand(1, 128 * 64, 3)
    coord = torch.rand(1, 128 * 64, 2)
    cell = torch.rand(1, 128 * 64, 2)

    # test on cpu
    output = model(inputs, coord, cell)
    output = model(inputs, coord, cell, True)
    assert torch.is_tensor(output)
    assert output.shape == targets.shape

    # test on gpu
    if torch.cuda.is_available():
        model = model.cuda()
        inputs = inputs.cuda()
        targets = targets.cuda()
        coord = coord.cuda()
        cell = cell.cuda()
        output = model(inputs, coord, cell)
        output = model(inputs, coord, cell, True)
        assert torch.is_tensor(output)
        assert output.shape == targets.shape
