-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add RealBasicVSR backbone (#632)
* add RealBasicVSR backbone * Add unittest * add test * Add unittest
- Loading branch information
1 parent
6afc390
commit 0dda55f
Showing
4 changed files
with
170 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
108 changes: 108 additions & 0 deletions
108
mmedit/models/backbones/sr_backbones/real_basicvsr_net.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch | ||
import torch.nn as nn | ||
from mmcv.runner import load_checkpoint | ||
|
||
from mmedit.models.backbones.sr_backbones.basicvsr_net import ( | ||
BasicVSRNet, ResidualBlocksWithInputConv) | ||
from mmedit.models.registry import BACKBONES | ||
from mmedit.utils import get_root_logger | ||
|
||
|
||
@BACKBONES.register_module() | ||
class RealBasicVSRNet(nn.Module): | ||
"""RealBasicVSR network structure for real-world video super-resolution. | ||
Support only x4 upsampling. | ||
Paper: | ||
Investigating Tradeoffs in Real-World Video Super-Resolution, arXiv | ||
Args: | ||
mid_channels (int, optional): Channel number of the intermediate | ||
features. Default: 64. | ||
num_propagation_blocks (int, optional): Number of residual blocks in | ||
each propagation branch. Default: 20. | ||
num_cleaning_blocks (int, optional): Number of residual blocks in the | ||
image cleaning module. Default: 20. | ||
dynamic_refine_thres (int, optional): Stop cleaning the images when | ||
the residue is smaller than this value. Default: 255. | ||
spynet_pretrained (str, optional): Pre-trained model path of SPyNet. | ||
Default: None. | ||
is_fix_cleaning (bool, optional): Whether to fix the weights of | ||
the image cleaning module during training. Default: False. | ||
is_sequential_cleaning (bool, optional): Whether to clean the images | ||
sequentially. This is used to save GPU memory, but the speed is | ||
slightly slower. Default: False. | ||
""" | ||
|
||
def __init__(self, | ||
mid_channels=64, | ||
num_propagation_blocks=20, | ||
num_cleaning_blocks=20, | ||
dynamic_refine_thres=255, | ||
spynet_pretrained=None, | ||
is_fix_cleaning=False, | ||
is_sequential_cleaning=False): | ||
|
||
super().__init__() | ||
|
||
self.dynamic_refine_thres = dynamic_refine_thres / 255. | ||
self.is_sequential_cleaning = is_sequential_cleaning | ||
|
||
# image cleaning module | ||
self.image_cleaning = nn.Sequential( | ||
ResidualBlocksWithInputConv(3, mid_channels, num_cleaning_blocks), | ||
nn.Conv2d(mid_channels, 3, 3, 1, 1, bias=True), | ||
) | ||
|
||
if is_fix_cleaning: # keep the weights of the cleaning module fixed | ||
self.image_cleaning.requires_grad_(False) | ||
|
||
# BasicVSR | ||
self.basicvsr = BasicVSRNet(mid_channels, num_propagation_blocks, | ||
spynet_pretrained) | ||
self.basicvsr.spynet.requires_grad_(False) | ||
|
||
def forward(self, lqs, return_lqs=False): | ||
n, t, c, h, w = lqs.size() | ||
|
||
for _ in range(0, 3): # at most 3 cleaning, determined empirically | ||
if self.is_sequential_cleaning: | ||
residues = [] | ||
for i in range(0, t): | ||
residue_i = self.image_cleaning(lqs[:, i, :, :, :]) | ||
lqs[:, i, :, :, :] += residue_i | ||
residues.append(residue_i) | ||
residues = torch.stack(residues, dim=1) | ||
else: # time -> batch, then apply cleaning at once | ||
lqs = lqs.view(-1, c, h, w) | ||
residues = self.image_cleaning(lqs) | ||
lqs = (lqs + residues).view(n, t, c, h, w) | ||
|
||
# determine whether to continue cleaning | ||
if torch.mean(torch.abs(residues)) < self.dynamic_refine_thres: | ||
break | ||
|
||
# Super-resolution (BasicVSR) | ||
outputs = self.basicvsr(lqs) | ||
|
||
if return_lqs: | ||
return outputs, lqs | ||
else: | ||
return outputs | ||
|
||
def init_weights(self, pretrained=None, strict=True): | ||
"""Init weights for models. | ||
Args: | ||
pretrained (str, optional): Path for pretrained weights. If given | ||
None, pretrained weights will not be loaded. Defaults: None. | ||
strict (boo, optional): Whether strictly load the pretrained model. | ||
Defaults to True. | ||
""" | ||
if isinstance(pretrained, str): | ||
logger = get_root_logger() | ||
load_checkpoint(self, pretrained, strict=strict, logger=logger) | ||
elif pretrained is not None: | ||
raise TypeError(f'"pretrained" must be a str or None. ' | ||
f'But received {type(pretrained)}.') |
57 changes: 57 additions & 0 deletions
57
tests/test_models/test_backbones/test_sr_backbones/test_real_basicvsr_net.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import pytest | ||
import torch | ||
|
||
from mmedit.models.backbones.sr_backbones.real_basicvsr_net import \ | ||
RealBasicVSRNet | ||
|
||
|
||
def test_real_basicvsr_net(): | ||
"""Test RealBasicVSR.""" | ||
|
||
# cpu | ||
# is_fix_cleaning = False | ||
real_basicvsr = RealBasicVSRNet(is_fix_cleaning=False) | ||
|
||
# is_sequential_cleaning = False | ||
real_basicvsr = RealBasicVSRNet( | ||
is_fix_cleaning=True, is_sequential_cleaning=False) | ||
input_tensor = torch.rand(1, 5, 3, 64, 64) | ||
real_basicvsr.init_weights(pretrained=None) | ||
output = real_basicvsr(input_tensor) | ||
assert output.shape == (1, 5, 3, 256, 256) | ||
|
||
# is_sequential_cleaning = True, return_lq = True | ||
real_basicvsr = RealBasicVSRNet( | ||
is_fix_cleaning=True, is_sequential_cleaning=True) | ||
output, lq = real_basicvsr(input_tensor, return_lqs=True) | ||
assert output.shape == (1, 5, 3, 256, 256) | ||
assert lq.shape == (1, 5, 3, 64, 64) | ||
|
||
with pytest.raises(TypeError): | ||
# pretrained should be str or None | ||
real_basicvsr.init_weights(pretrained=[1]) | ||
|
||
# gpu | ||
if torch.cuda.is_available(): | ||
# is_fix_cleaning = False | ||
real_basicvsr = RealBasicVSRNet(is_fix_cleaning=False).cuda() | ||
|
||
# is_sequential_cleaning = False | ||
real_basicvsr = RealBasicVSRNet( | ||
is_fix_cleaning=True, is_sequential_cleaning=False).cuda() | ||
input_tensor = torch.rand(1, 5, 3, 64, 64).cuda() | ||
real_basicvsr.init_weights(pretrained=None) | ||
output = real_basicvsr(input_tensor) | ||
assert output.shape == (1, 5, 3, 256, 256) | ||
|
||
# is_sequential_cleaning = True, return_lq = True | ||
real_basicvsr = RealBasicVSRNet( | ||
is_fix_cleaning=True, is_sequential_cleaning=True).cuda() | ||
output, lq = real_basicvsr(input_tensor, return_lqs=True) | ||
assert output.shape == (1, 5, 3, 256, 256) | ||
assert lq.shape == (1, 5, 3, 64, 64) | ||
|
||
with pytest.raises(TypeError): | ||
# pretrained should be str or None | ||
real_basicvsr.init_weights(pretrained=[1]) |