Commit
Add TDAN backbone (#316)
* Add TDAN architecture

* Modify docstring

* Fix bug in unittest

* Update backbone

* Minor update

* Change TDANNet arguments
ckkelvinchan authored May 28, 2021
1 parent b72baab commit 92cebc1
Showing 4 changed files with 194 additions and 3 deletions.
4 changes: 2 additions & 2 deletions mmedit/models/backbones/__init__.py
@@ -14,7 +14,7 @@
from .generation_backbones import ResnetGenerator, UnetGenerator
from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet,
GLEANStyleGANv2, IconVSR, MSRResNet, RRDBNet,
TOFlow, TTSRNet)
TDANNet, TOFlow, TTSRNet)

__all__ = [
'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
@@ -26,5 +26,5 @@
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
'BasicVSRNet', 'IconVSR', 'TTSRNet', 'GLEANStyleGANv2'
'BasicVSRNet', 'IconVSR', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet'
]
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
@@ -7,10 +7,11 @@
from .rrdb_net import RRDBNet
from .sr_resnet import MSRResNet
from .srcnn import SRCNN
from .tdan_net import TDANNet
from .tof import TOFlow
from .ttsr_net import TTSRNet

__all__ = [
'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN',
'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet', 'GLEANStyleGANv2'
'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet', 'GLEANStyleGANv2', 'TDANNet'
]
167 changes: 167 additions & 0 deletions mmedit/models/backbones/sr_backbones/tdan_net.py
@@ -0,0 +1,167 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init
from mmcv.ops import DeformConv2d, DeformConv2dPack, deform_conv2d
from mmcv.runner import load_checkpoint
from torch.nn.modules.utils import _pair

from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


class AugmentedDeformConv2dPack(DeformConv2d):
"""Augmented Deformable Convolution Pack.
Different from DeformConv2dPack, which generates offsets from the
preceeding feature, this AugmentedDeformConv2dPack takes another feature to
generate the offsets.
Args:
in_channels (int): Number of channels in the input feature.
out_channels (int): Number of channels produced by the convolution.
kernel_size (int or tuple[int]): Size of the convolving kernel.
stride (int or tuple[int]): Stride of the convolution. Default: 1.
padding (int or tuple[int]): Zero-padding added to both sides of the
input. Default: 0.
dilation (int or tuple[int]): Spacing between kernel elements.
Default: 1.
groups (int): Number of blocked connections from input channels to
output channels. Default: 1.
deform_groups (int): Number of deformable group partitions.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)

self.init_offset()

def init_offset(self):
constant_init(self.conv_offset, val=0, bias=0)

def forward(self, x, extra_feat):
offset = self.conv_offset(extra_feat)
return deform_conv2d(x, offset, self.weight, self.stride, self.padding,
self.dilation, self.groups, self.deform_groups)


@BACKBONES.register_module()
class TDANNet(nn.Module):
"""TDAN network structure for video super-resolution.
Support only x4 upsampling.
Paper:
TDAN: Temporally-Deformable Alignment Network for Video Super-
Resolution, CVPR, 2020
Args:
in_channels (int): Number of channels of the input image. Default: 3.
mid_channels (int): Number of channels of the intermediate features.
Default: 64.
out_channels (int): Number of channels of the output image. Default: 3.
num_blocks_before_align (int): Number of residual blocks before
temporal alignment. Default: 5.
num_blocks_before_align (int): Number of residual blocks after
temporal alignment. Default: 10.
"""

def __init__(self,
in_channels=3,
mid_channels=64,
out_channels=3,
num_blocks_before_align=5,
num_blocks_after_align=10):

super().__init__()

self.feat_extract = nn.Sequential(
ConvModule(in_channels, mid_channels, 3, padding=1),
make_layer(
ResidualBlockNoBN,
num_blocks_before_align,
mid_channels=mid_channels))

self.feat_aggregate = nn.Sequential(
nn.Conv2d(mid_channels * 2, mid_channels, 3, padding=1, bias=True),
DeformConv2dPack(
mid_channels, mid_channels, 3, padding=1, deform_groups=8),
DeformConv2dPack(
mid_channels, mid_channels, 3, padding=1, deform_groups=8))
self.align_1 = AugmentedDeformConv2dPack(
mid_channels, mid_channels, 3, padding=1, deform_groups=8)
self.align_2 = DeformConv2dPack(
mid_channels, mid_channels, 3, padding=1, deform_groups=8)
self.to_rgb = nn.Conv2d(mid_channels, 3, 3, padding=1, bias=True)

self.reconstruct = nn.Sequential(
ConvModule(in_channels * 5, mid_channels, 3, padding=1),
make_layer(
ResidualBlockNoBN,
num_blocks_after_align,
mid_channels=mid_channels),
PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False))

def forward(self, lrs):
"""Forward function for TDANNet.
Args:
lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).
Returns:
tuple[Tensor]: Output HR image with shape (n, c, 4h, 4w) and
aligned LR images with shape (n, t, c, h, w).
"""
n, t, c, h, w = lrs.size()
lr_center = lrs[:, t // 2, :, :, :] # LR center frame

# extract features
feats = self.feat_extract(lrs.view(-1, c, h, w)).view(n, t, -1, h, w)

# alignment of LR frames
feat_center = feats[:, t // 2, :, :, :].contiguous()
aligned_lrs = []
for i in range(0, t):
if i == t // 2:
aligned_lrs.append(lr_center)
else:
feat_neig = feats[:, i, :, :, :].contiguous()
feat_agg = torch.cat([feat_center, feat_neig], dim=1)
feat_agg = self.feat_aggregate(feat_agg)

aligned_feat = self.align_2(self.align_1(feat_neig, feat_agg))
aligned_lrs.append(self.to_rgb(aligned_feat))

aligned_lrs = torch.cat(aligned_lrs, dim=1)

# output HR center frame and the aligned LR frames
return self.reconstruct(aligned_lrs), aligned_lrs.view(n, t, c, h, w)

def init_weights(self, pretrained=None, strict=True):
"""Init weights for models.
Args:
pretrained (str, optional): Path for pretrained weights. If given
None, pretrained weights will not be loaded. Defaults: None.
strict (boo, optional): Whether strictly load the pretrained model.
Defaults to True.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=strict, logger=logger)
elif pretrained is not None:
raise TypeError(f'"pretrained" must be a str or None. '
f'But received {type(pretrained)}.')
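
Note on usage (not part of the diff): the two-input forward of AugmentedDeformConv2dPack differs from the usual one-input DeformConv2dPack, so a minimal standalone sketch of how it is called may help. It assumes a CUDA device, since mmcv's deformable-convolution ops are GPU-only.

import torch

from mmedit.models.backbones.sr_backbones.tdan_net import AugmentedDeformConv2dPack

# mmcv's deform_conv2d op runs on GPU only.
conv = AugmentedDeformConv2dPack(64, 64, 3, padding=1, deform_groups=8).cuda()

feat = torch.rand(1, 64, 32, 32).cuda()        # feature to be aligned
extra_feat = torch.rand(1, 64, 32, 32).cuda()  # feature that generates the offsets

# Offsets are predicted from extra_feat, then applied when convolving feat.
out = conv(feat, extra_feat)
print(out.shape)  # torch.Size([1, 64, 32, 32])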
23 changes: 23 additions & 0 deletions tests/test_tdan_net.py
@@ -0,0 +1,23 @@
import pytest
import torch

from mmedit.models.backbones.sr_backbones.tdan_net import TDANNet


def test_tdan_net():
"""Test TDANNet."""

    # gpu (DCN is available only on GPU)
if torch.cuda.is_available():
tdan = TDANNet().cuda()
input_tensor = torch.rand(1, 5, 3, 64, 64).cuda()
tdan.init_weights(pretrained=None)

output = tdan(input_tensor)
assert len(output) == 2 # (1) HR center + (2) aligned LRs
assert output[0].shape == (1, 3, 256, 256) # HR center frame
assert output[1].shape == (1, 5, 3, 64, 64) # aligned LRs

with pytest.raises(TypeError):
# pretrained should be str or None
tdan.init_weights(pretrained=[1])

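Because of the @BACKBONES.register_module() decorator in the diff above, the new backbone can also be constructed from a config dict. A minimal sketch, assuming mmedit.models.build_backbone and a CUDA device (shapes follow the unit test above):

import torch

from mmedit.models import build_backbone

# TDANNet is registered in BACKBONES, so it can be built from a config dict,
# which is how a model config file would reference it.
model = build_backbone(dict(type='TDANNet', mid_channels=64)).cuda()
model.init_weights(pretrained=None)

lrs = torch.rand(1, 5, 3, 64, 64).cuda()  # (n, t, c, h, w) LR clip
hr_center, aligned_lrs = model(lrs)
print(hr_center.shape)    # torch.Size([1, 3, 256, 256]) -- x4 HR center frame
print(aligned_lrs.shape)  # torch.Size([1, 5, 3, 64, 64]) -- aligned LR frames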