From b0c54420fd2ead3f0b15df685332c926e758d233 Mon Sep 17 00:00:00 2001
From: ckkelvinchan
Date: Tue, 18 May 2021 15:19:50 +0800
Subject: [PATCH 1/6] Add TDAN architecture

---
 mmedit/models/backbones/__init__.py           |   4 +-
 .../models/backbones/sr_backbones/__init__.py |   3 +-
 .../models/backbones/sr_backbones/tdan_net.py | 138 ++++++++++++++++++
 tests/test_tdan_net.py                        |  23 +++
 4 files changed, 165 insertions(+), 3 deletions(-)
 create mode 100644 mmedit/models/backbones/sr_backbones/tdan_net.py
 create mode 100644 tests/test_tdan_net.py

diff --git a/mmedit/models/backbones/__init__.py b/mmedit/models/backbones/__init__.py
index 444bb01c88..a7e5dc767e 100644
--- a/mmedit/models/backbones/__init__.py
+++ b/mmedit/models/backbones/__init__.py
@@ -13,7 +13,7 @@
 # yapf: enable
 from .generation_backbones import ResnetGenerator, UnetGenerator
 from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet, IconVSR,
-                           MSRResNet, RRDBNet, TOFlow)
+                           MSRResNet, RRDBNet, TDANNet, TOFlow)
 
 __all__ = [
     'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
@@ -25,5 +25,5 @@
     'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
     'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
     'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
-    'BasicVSRNet', 'IconVSR'
+    'BasicVSRNet', 'IconVSR', 'TDANNet'
 ]
diff --git a/mmedit/models/backbones/sr_backbones/__init__.py b/mmedit/models/backbones/sr_backbones/__init__.py
index d96a81a9ca..74ce4ec284 100644
--- a/mmedit/models/backbones/sr_backbones/__init__.py
+++ b/mmedit/models/backbones/sr_backbones/__init__.py
@@ -6,9 +6,10 @@
 from .rrdb_net import RRDBNet
 from .sr_resnet import MSRResNet
 from .srcnn import SRCNN
+from .tdan_net import TDANNet
 from .tof import TOFlow
 
 __all__ = [
     'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN',
-    'BasicVSRNet', 'IconVSR', 'RDN'
+    'BasicVSRNet', 'IconVSR', 'RDN', 'TDANNet'
 ]
diff --git a/mmedit/models/backbones/sr_backbones/tdan_net.py b/mmedit/models/backbones/sr_backbones/tdan_net.py
new file mode 100644
index 0000000000..f2577b55c1
--- /dev/null
+++ b/mmedit/models/backbones/sr_backbones/tdan_net.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, constant_init
+from mmcv.ops import DeformConv2dPack, deform_conv2d
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.utils import _pair
+
+from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
+                                  make_layer)
+from mmedit.models.registry import BACKBONES
+from mmedit.utils import get_root_logger
+
+
+class AugmentedDeformConv2dPack(DeformConv2dPack):
+    """Augmented Deformable Convolution Pack.
+
+    Different from the official DCN, which generates offsets from the
+    preceding feature, this AugmentedDeformConv2dPack takes another feature to
+    generate the offsets.
+
+    Args:
+        in_channels (int): Same as nn.Conv2d.
+        out_channels (int): Same as nn.Conv2d.
+        kernel_size (int or tuple[int]): Same as nn.Conv2d.
+        stride (int or tuple[int]): Same as nn.Conv2d.
+        padding (int or tuple[int]): Same as nn.Conv2d.
+        dilation (int or tuple[int]): Same as nn.Conv2d.
+        groups (int): Same as nn.Conv2d.
+        bias (bool or str): If specified as `auto`, it will be decided by the
+            norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+            False.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.conv_offset = nn.Conv2d(
+            self.in_channels,
+            self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+            kernel_size=self.kernel_size,
+            stride=_pair(self.stride),
+            padding=_pair(self.padding),
+            bias=True)
+
+        self.init_offset()
+
+    def init_offset(self):
+        constant_init(self.conv_offset, val=0, bias=0)
+
+    def forward(self, x, extra_feat):
+        offset = self.conv_offset(extra_feat)
+        return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
+                             self.dilation, self.groups, self.deform_groups)
+
+
+@BACKBONES.register_module()
+class TDANNet(nn.Module):
+    """TDAN network structure for video super-resolution.
+
+    Support only x4 upsampling.
+    Paper:
+        TDAN: Temporally-Deformable Alignment Network for Video Super-
+        Resolution, CVPR, 2020
+    """
+
+    def __init__(self):
+
+        super().__init__()
+
+        self.feat_extract = nn.Sequential(
+            ConvModule(3, 64, 3, padding=1),
+            make_layer(ResidualBlockNoBN, 5, mid_channels=64))
+
+        self.feat_aggregate = nn.Sequential(
+            nn.Conv2d(128, 64, 3, padding=1, bias=True),
+            DeformConv2dPack(64, 64, 3, padding=1, deform_groups=8),
+            DeformConv2dPack(64, 64, 3, padding=1, deform_groups=8))
+        self.align_1 = AugmentedDeformConv2dPack(
+            64, 64, 3, padding=1, deform_groups=8)
+        self.align_2 = DeformConv2dPack(64, 64, 3, padding=1, deform_groups=8)
+        self.to_rgb = nn.Conv2d(64, 3, 3, padding=1, bias=True)
+
+        self.reconstruct = nn.Sequential(
+            ConvModule(3 * 5, 64, 3, padding=1),
+            make_layer(ResidualBlockNoBN, 10, mid_channels=64),
+            PixelShufflePack(64, 64, 2, upsample_kernel=3),
+            PixelShufflePack(64, 64, 2, upsample_kernel=3),
+            nn.Conv2d(64, 3, 3, 1, 1, bias=False))
+
+    def forward(self, lrs):
+        """Forward function for TDANNet.
+
+        Args:
+            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).
+
+        Returns:
+            Tensor: Output HR image with shape (n, c, 4h, 4w).
+        """
+
+        n, t, c, h, w = lrs.size()
+        lr_center = lrs[:, t // 2, :, :, :]  # LR center frame
+
+        # extract features
+        feats = self.feat_extract(lrs.view(-1, c, h, w)).view(n, t, -1, h, w)
+
+        # alignment of LR frames
+        feat_center = feats[:, t // 2, :, :, :].contiguous()
+        aligned_lrs = []
+        for i in range(0, t):
+            if i == t // 2:
+                aligned_lrs.append(lr_center)
+            else:
+                feat_neig = feats[:, i, :, :, :].contiguous()
+                feat_agg = torch.cat([feat_center, feat_neig], dim=1)
+                feat_agg = self.feat_aggregate(feat_agg)
+
+                aligned_feat = self.align_2(self.align_1(feat_neig, feat_agg))
+                aligned_lrs.append(self.to_rgb(aligned_feat))
+        aligned_lrs = torch.cat(aligned_lrs, dim=1)
+
+        # return HR center frame and the aligned LR frames
+        return self.reconstruct(aligned_lrs), aligned_lrs.view(n, t, c, h, w)
+
+    def init_weights(self, pretrained=None, strict=True):
+        """Init weights for models.
+
+        Args:
+            pretrained (str, optional): Path for pretrained weights. If given
+                None, pretrained weights will not be loaded. Defaults to None.
+            strict (bool, optional): Whether to strictly load the pretrained
+                model. Defaults to True.
+        """
+        if isinstance(pretrained, str):
+            logger = get_root_logger()
+            load_checkpoint(self, pretrained, strict=strict, logger=logger)
+        elif pretrained is not None:
+            raise TypeError(f'"pretrained" must be a str or None. '
+                            f'But received {type(pretrained)}.')
diff --git a/tests/test_tdan_net.py b/tests/test_tdan_net.py
new file mode 100644
index 0000000000..74ae067d36
--- /dev/null
+++ b/tests/test_tdan_net.py
@@ -0,0 +1,23 @@
+import pytest
+import torch
+
+from mmedit.models.backbones.sr_backbones.tdan_net import TDANNet
+
+
+def test_tdan_net():
+    """Test TDANNet."""
+
+    # gpu (DCN is available only on GPU)
+    if torch.cuda.is_available():
+        tdan = TDANNet().cuda()
+        input_tensor = torch.rand(1, 5, 3, 64, 64).cuda()
+        tdan.init_weights(pretrained=None)
+
+        output = tdan(input_tensor)
+        assert len(output) == 2  # (1) HR center + (2) aligned LRs
+        assert output[0].shape == (1, 3, 256, 256)  # HR center frame
+        assert output[1].shape == (1, 5, 3, 64, 64)  # aligned LRs
+
+    with pytest.raises(TypeError):
+        # pretrained should be str or None
+        tdan.init_weights(pretrained=[1])

From 88208cfa8d0bde3c96b578fbce0e3a8feaf4b8e4 Mon Sep 17 00:00:00 2001
From: ckkelvinchan
Date: Tue, 18 May 2021 15:23:14 +0800
Subject: [PATCH 2/6] Modify docstring

---
 mmedit/models/backbones/sr_backbones/tdan_net.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmedit/models/backbones/sr_backbones/tdan_net.py b/mmedit/models/backbones/sr_backbones/tdan_net.py
index f2577b55c1..2b94d78948 100644
--- a/mmedit/models/backbones/sr_backbones/tdan_net.py
+++ b/mmedit/models/backbones/sr_backbones/tdan_net.py
@@ -14,7 +14,7 @@
 class AugmentedDeformConv2dPack(DeformConv2dPack):
     """Augmented Deformable Convolution Pack.
 
-    Different from the official DCN, which generates offsets from the
+    Different from DeformConv2dPack, which generates offsets from the
     preceding feature, this AugmentedDeformConv2dPack takes another feature to
     generate the offsets.
 

From 3e9afab5c5a741b61243c691e34b364edca65977 Mon Sep 17 00:00:00 2001
From: ckkelvinchan
Date: Tue, 18 May 2021 15:30:53 +0800
Subject: [PATCH 3/6] Fix bug in unittest

---
 tests/test_tdan_net.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_tdan_net.py b/tests/test_tdan_net.py
index 74ae067d36..de36a52c21 100644
--- a/tests/test_tdan_net.py
+++ b/tests/test_tdan_net.py
@@ -18,6 +18,6 @@ def test_tdan_net():
         assert output[0].shape == (1, 3, 256, 256)  # HR center frame
         assert output[1].shape == (1, 5, 3, 64, 64)  # aligned LRs
 
-    with pytest.raises(TypeError):
-        # pretrained should be str or None
-        tdan.init_weights(pretrained=[1])
+        with pytest.raises(TypeError):
+            # pretrained should be str or None
+            tdan.init_weights(pretrained=[1])

From 807228bb71ddc0b69f30a3b19cdc3826ef297afb Mon Sep 17 00:00:00 2001
From: ckkelvinchan
Date: Wed, 26 May 2021 10:28:25 +0800
Subject: [PATCH 4/6] Update backbone

---
 .../models/backbones/sr_backbones/tdan_net.py | 74 ++++++++++++-------
 1 file changed, 48 insertions(+), 26 deletions(-)

diff --git a/mmedit/models/backbones/sr_backbones/tdan_net.py b/mmedit/models/backbones/sr_backbones/tdan_net.py
index 2b94d78948..331503acc3 100644
--- a/mmedit/models/backbones/sr_backbones/tdan_net.py
+++ b/mmedit/models/backbones/sr_backbones/tdan_net.py
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 from mmcv.cnn import ConvModule, constant_init
-from mmcv.ops import DeformConv2dPack, deform_conv2d
+from mmcv.ops import DeformConv2d, DeformConv2dPack, deform_conv2d
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.utils import _pair
 
@@ -11,7 +11,7 @@ from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
 from mmedit.utils import get_root_logger
 
 
-class AugmentedDeformConv2dPack(DeformConv2dPack):
+class AugmentedDeformConv2dPack(DeformConv2d):
     """Augmented Deformable Convolution Pack.
 
     Different from DeformConv2dPack, which generates offsets from the
@@ -18,14 +18,18 @@ class AugmentedDeformConv2dPack(DeformConv2dPack):
     preceding feature, this AugmentedDeformConv2dPack takes another feature to
     generate the offsets.
 
     Args:
-        in_channels (int): Same as nn.Conv2d.
-        out_channels (int): Same as nn.Conv2d.
-        kernel_size (int or tuple[int]): Same as nn.Conv2d.
-        stride (int or tuple[int]): Same as nn.Conv2d.
-        padding (int or tuple[int]): Same as nn.Conv2d.
-        dilation (int or tuple[int]): Same as nn.Conv2d.
-        groups (int): Same as nn.Conv2d.
+        in_channels (int): Number of channels in the input feature.
+        out_channels (int): Number of channels produced by the convolution.
+        kernel_size (int or tuple[int]): Size of the convolving kernel.
+        stride (int or tuple[int]): Stride of the convolution. Default: 1.
+        padding (int or tuple[int]): Zero-padding added to both sides of the
+            input. Default: 0.
+        dilation (int or tuple[int]): Spacing between kernel elements.
+            Default: 1.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Default: 1.
+        deform_groups (int): Number of deformable group partitions.
         bias (bool or str): If specified as `auto`, it will be decided by the
             norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
             False.
@@ -61,31 +65,48 @@ class TDANNet(nn.Module):
     """TDAN network structure for video super-resolution.
 
     Support only x4 upsampling.
     Paper:
         TDAN: Temporally-Deformable Alignment Network for Video Super-
         Resolution, CVPR, 2020
+
+    Args:
+        in_channels (int): Number of channels of the input image. Default: 3.
+        mid_channels (int): Number of channels of the intermediate features.
+            Default: 64.
+        out_channels (int): Number of channels of the output image. Default: 3.
+        num_blocks (list[int]): Number of residual blocks before and after
+            temporal alignment. Default: [5, 10].
     """
 
-    def __init__(self):
+    def __init__(self,
+                 in_channels=3,
+                 mid_channels=64,
+                 out_channels=3,
+                 num_blocks=[5, 10]):
 
         super().__init__()
 
         self.feat_extract = nn.Sequential(
-            ConvModule(3, 64, 3, padding=1),
-            make_layer(ResidualBlockNoBN, 5, mid_channels=64))
+            ConvModule(in_channels, mid_channels, 3, padding=1),
+            make_layer(
+                ResidualBlockNoBN, num_blocks[0], mid_channels=mid_channels))
 
         self.feat_aggregate = nn.Sequential(
-            nn.Conv2d(128, 64, 3, padding=1, bias=True),
-            DeformConv2dPack(64, 64, 3, padding=1, deform_groups=8),
-            DeformConv2dPack(64, 64, 3, padding=1, deform_groups=8))
+            nn.Conv2d(mid_channels * 2, mid_channels, 3, padding=1, bias=True),
+            DeformConv2dPack(
+                mid_channels, mid_channels, 3, padding=1, deform_groups=8),
+            DeformConv2dPack(
+                mid_channels, mid_channels, 3, padding=1, deform_groups=8))
         self.align_1 = AugmentedDeformConv2dPack(
-            64, 64, 3, padding=1, deform_groups=8)
-        self.align_2 = DeformConv2dPack(64, 64, 3, padding=1, deform_groups=8)
-        self.to_rgb = nn.Conv2d(64, 3, 3, padding=1, bias=True)
+            mid_channels, mid_channels, 3, padding=1, deform_groups=8)
+        self.align_2 = DeformConv2dPack(
+            mid_channels, mid_channels, 3, padding=1, deform_groups=8)
+        self.to_rgb = nn.Conv2d(mid_channels, 3, 3, padding=1, bias=True)
 
         self.reconstruct = nn.Sequential(
-            ConvModule(3 * 5, 64, 3, padding=1),
-            make_layer(ResidualBlockNoBN, 10, mid_channels=64),
-            PixelShufflePack(64, 64, 2, upsample_kernel=3),
-            PixelShufflePack(64, 64, 2, upsample_kernel=3),
-            nn.Conv2d(64, 3, 3, 1, 1, bias=False))
+            ConvModule(mid_channels * 5, mid_channels, 3, padding=1),
+            make_layer(
+                ResidualBlockNoBN, num_blocks[1], mid_channels=mid_channels),
+            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
+            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
+            nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False))
 
     def forward(self, lrs):
         """Forward function for TDANNet.
@@ -94,9 +115,9 @@ def forward(self, lrs):
             lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).
 
         Returns:
-            Tensor: Output HR image with shape (n, c, 4h, 4w).
+            tuple[Tensor]: Output HR image with shape (n, c, 4h, 4w) and
+                aligned LR images with shape (n, t, c, h, w).
         """
-
         n, t, c, h, w = lrs.size()
         lr_center = lrs[:, t // 2, :, :, :]  # LR center frame
@@ -116,9 +137,10 @@ def forward(self, lrs):
                 aligned_feat = self.align_2(self.align_1(feat_neig, feat_agg))
                 aligned_lrs.append(self.to_rgb(aligned_feat))
+
         aligned_lrs = torch.cat(aligned_lrs, dim=1)
 
-        # return HR center frame and the aligned LR frames
+        # output HR center frame and the aligned LR frames
         return self.reconstruct(aligned_lrs), aligned_lrs.view(n, t, c, h, w)
 
     def init_weights(self, pretrained=None, strict=True):

From 48fa4df7a5629cc6a68b559229a2cc28b023906c Mon Sep 17 00:00:00 2001
From: ckkelvinchan
Date: Wed, 26 May 2021 10:39:06 +0800
Subject: [PATCH 5/6] Minor update

---
 mmedit/models/backbones/sr_backbones/tdan_net.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmedit/models/backbones/sr_backbones/tdan_net.py b/mmedit/models/backbones/sr_backbones/tdan_net.py
index 331503acc3..af30f69035 100644
--- a/mmedit/models/backbones/sr_backbones/tdan_net.py
+++ b/mmedit/models/backbones/sr_backbones/tdan_net.py
@@ -101,7 +101,7 @@ def __init__(self,
         self.to_rgb = nn.Conv2d(mid_channels, 3, 3, padding=1, bias=True)
 
         self.reconstruct = nn.Sequential(
-            ConvModule(mid_channels * 5, mid_channels, 3, padding=1),
+            ConvModule(in_channels * 5, mid_channels, 3, padding=1),
             make_layer(
                 ResidualBlockNoBN, num_blocks[1], mid_channels=mid_channels),
             PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),

From 180615fdacdfefdfd55b29472adc28999e70530b Mon Sep 17 00:00:00 2001
From: ckkelvinchan
Date: Fri, 28 May 2021 10:41:57 +0800
Subject: [PATCH 6/6] Change TDANNet arguments

---
 .../models/backbones/sr_backbones/tdan_net.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/mmedit/models/backbones/sr_backbones/tdan_net.py b/mmedit/models/backbones/sr_backbones/tdan_net.py
index af30f69035..e39e8d915d 100644
--- a/mmedit/models/backbones/sr_backbones/tdan_net.py
+++ b/mmedit/models/backbones/sr_backbones/tdan_net.py
@@ -71,22 +71,27 @@ class TDANNet(nn.Module):
         mid_channels (int): Number of channels of the intermediate features.
             Default: 64.
         out_channels (int): Number of channels of the output image. Default: 3.
-        num_blocks (list[int]): Number of residual blocks before and after
-            temporal alignment. Default: [5, 10].
+        num_blocks_before_align (int): Number of residual blocks before
+            temporal alignment. Default: 5.
+        num_blocks_after_align (int): Number of residual blocks after
+            temporal alignment. Default: 10.
""" def __init__(self, in_channels=3, mid_channels=64, out_channels=3, - num_blocks=[5, 10]): + num_blocks_before_align=5, + num_blocks_after_align=10): super().__init__() self.feat_extract = nn.Sequential( ConvModule(in_channels, mid_channels, 3, padding=1), make_layer( - ResidualBlockNoBN, num_blocks[0], mid_channels=mid_channels)) + ResidualBlockNoBN, + num_blocks_before_align, + mid_channels=mid_channels)) self.feat_aggregate = nn.Sequential( nn.Conv2d(mid_channels * 2, mid_channels, 3, padding=1, bias=True), @@ -103,7 +108,9 @@ def __init__(self, self.reconstruct = nn.Sequential( ConvModule(in_channels * 5, mid_channels, 3, padding=1), make_layer( - ResidualBlockNoBN, num_blocks[1], mid_channels=mid_channels), + ResidualBlockNoBN, + num_blocks_after_align, + mid_channels=mid_channels), PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3), PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3), nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False))