Skip to content

Commit

Permalink
[Feature] Add RealBasicVSR backbone (#632)
Browse files Browse the repository at this point in the history
* add RealBasicVSR backbone

* Add unittest

* add test

* Add unittest
  • Loading branch information
ckkelvinchan authored Dec 4, 2021
1 parent 6afc390 commit 0dda55f
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 4 deletions.
6 changes: 3 additions & 3 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from .generation_backbones import ResnetGenerator, UnetGenerator
from .sr_backbones import (EDSR, LIIFEDSR, LIIFRDN, RDN, SRCNN, BasicVSRNet,
BasicVSRPlusPlus, DICNet, EDVRNet, GLEANStyleGANv2,
IconVSR, MSRResNet, RRDBNet, TDANNet, TOFlow,
TTSRNet)
IconVSR, MSRResNet, RealBasicVSRNet, RRDBNet,
TDANNet, TOFlow, TTSRNet)

__all__ = [
'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
Expand All @@ -27,5 +27,5 @@
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
'BasicVSRNet', 'IconVSR', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet',
'LIIFEDSR', 'LIIFRDN', 'BasicVSRPlusPlus'
'LIIFEDSR', 'LIIFRDN', 'BasicVSRPlusPlus', 'RealBasicVSRNet'
]
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .iconvsr import IconVSR
from .liif_net import LIIFEDSR, LIIFRDN
from .rdn import RDN
from .real_basicvsr_net import RealBasicVSRNet
from .rrdb_net import RRDBNet
from .sr_resnet import MSRResNet
from .srcnn import SRCNN
Expand All @@ -18,5 +19,5 @@
__all__ = [
'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', 'DICNet',
'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet',
'LIIFEDSR', 'LIIFRDN', 'BasicVSRPlusPlus'
'LIIFEDSR', 'LIIFRDN', 'BasicVSRPlusPlus', 'RealBasicVSRNet'
]
108 changes: 108 additions & 0 deletions mmedit/models/backbones/sr_backbones/real_basicvsr_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.runner import load_checkpoint

from mmedit.models.backbones.sr_backbones.basicvsr_net import (
BasicVSRNet, ResidualBlocksWithInputConv)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


@BACKBONES.register_module()
class RealBasicVSRNet(nn.Module):
"""RealBasicVSR network structure for real-world video super-resolution.
Support only x4 upsampling.
Paper:
Investigating Tradeoffs in Real-World Video Super-Resolution, arXiv
Args:
mid_channels (int, optional): Channel number of the intermediate
features. Default: 64.
num_propagation_blocks (int, optional): Number of residual blocks in
each propagation branch. Default: 20.
num_cleaning_blocks (int, optional): Number of residual blocks in the
image cleaning module. Default: 20.
dynamic_refine_thres (int, optional): Stop cleaning the images when
the residue is smaller than this value. Default: 255.
spynet_pretrained (str, optional): Pre-trained model path of SPyNet.
Default: None.
is_fix_cleaning (bool, optional): Whether to fix the weights of
the image cleaning module during training. Default: False.
is_sequential_cleaning (bool, optional): Whether to clean the images
sequentially. This is used to save GPU memory, but the speed is
slightly slower. Default: False.
"""

def __init__(self,
mid_channels=64,
num_propagation_blocks=20,
num_cleaning_blocks=20,
dynamic_refine_thres=255,
spynet_pretrained=None,
is_fix_cleaning=False,
is_sequential_cleaning=False):

super().__init__()

self.dynamic_refine_thres = dynamic_refine_thres / 255.
self.is_sequential_cleaning = is_sequential_cleaning

# image cleaning module
self.image_cleaning = nn.Sequential(
ResidualBlocksWithInputConv(3, mid_channels, num_cleaning_blocks),
nn.Conv2d(mid_channels, 3, 3, 1, 1, bias=True),
)

if is_fix_cleaning: # keep the weights of the cleaning module fixed
self.image_cleaning.requires_grad_(False)

# BasicVSR
self.basicvsr = BasicVSRNet(mid_channels, num_propagation_blocks,
spynet_pretrained)
self.basicvsr.spynet.requires_grad_(False)

def forward(self, lqs, return_lqs=False):
n, t, c, h, w = lqs.size()

for _ in range(0, 3): # at most 3 cleaning, determined empirically
if self.is_sequential_cleaning:
residues = []
for i in range(0, t):
residue_i = self.image_cleaning(lqs[:, i, :, :, :])
lqs[:, i, :, :, :] += residue_i
residues.append(residue_i)
residues = torch.stack(residues, dim=1)
else: # time -> batch, then apply cleaning at once
lqs = lqs.view(-1, c, h, w)
residues = self.image_cleaning(lqs)
lqs = (lqs + residues).view(n, t, c, h, w)

# determine whether to continue cleaning
if torch.mean(torch.abs(residues)) < self.dynamic_refine_thres:
break

# Super-resolution (BasicVSR)
outputs = self.basicvsr(lqs)

if return_lqs:
return outputs, lqs
else:
return outputs

def init_weights(self, pretrained=None, strict=True):
"""Init weights for models.
Args:
pretrained (str, optional): Path for pretrained weights. If given
None, pretrained weights will not be loaded. Defaults: None.
strict (boo, optional): Whether strictly load the pretrained model.
Defaults to True.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif pretrained is not None:
raise TypeError(f'"pretrained" must be a str or None. '
f'But received {type(pretrained)}.')
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmedit.models.backbones.sr_backbones.real_basicvsr_net import \
RealBasicVSRNet


def test_real_basicvsr_net():
"""Test RealBasicVSR."""

# cpu
# is_fix_cleaning = False
real_basicvsr = RealBasicVSRNet(is_fix_cleaning=False)

# is_sequential_cleaning = False
real_basicvsr = RealBasicVSRNet(
is_fix_cleaning=True, is_sequential_cleaning=False)
input_tensor = torch.rand(1, 5, 3, 64, 64)
real_basicvsr.init_weights(pretrained=None)
output = real_basicvsr(input_tensor)
assert output.shape == (1, 5, 3, 256, 256)

# is_sequential_cleaning = True, return_lq = True
real_basicvsr = RealBasicVSRNet(
is_fix_cleaning=True, is_sequential_cleaning=True)
output, lq = real_basicvsr(input_tensor, return_lqs=True)
assert output.shape == (1, 5, 3, 256, 256)
assert lq.shape == (1, 5, 3, 64, 64)

with pytest.raises(TypeError):
# pretrained should be str or None
real_basicvsr.init_weights(pretrained=[1])

# gpu
if torch.cuda.is_available():
# is_fix_cleaning = False
real_basicvsr = RealBasicVSRNet(is_fix_cleaning=False).cuda()

# is_sequential_cleaning = False
real_basicvsr = RealBasicVSRNet(
is_fix_cleaning=True, is_sequential_cleaning=False).cuda()
input_tensor = torch.rand(1, 5, 3, 64, 64).cuda()
real_basicvsr.init_weights(pretrained=None)
output = real_basicvsr(input_tensor)
assert output.shape == (1, 5, 3, 256, 256)

# is_sequential_cleaning = True, return_lq = True
real_basicvsr = RealBasicVSRNet(
is_fix_cleaning=True, is_sequential_cleaning=True).cuda()
output, lq = real_basicvsr(input_tensor, return_lqs=True)
assert output.shape == (1, 5, 3, 256, 256)
assert lq.shape == (1, 5, 3, 64, 64)

with pytest.raises(TypeError):
# pretrained should be str or None
real_basicvsr.init_weights(pretrained=[1])

0 comments on commit 0dda55f

Please sign in to comment.