Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TDAN backbone #316

Merged
merged 8 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mmedit/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# yapf: enable
from .generation_backbones import ResnetGenerator, UnetGenerator
from .sr_backbones import (EDSR, RDN, SRCNN, BasicVSRNet, EDVRNet, IconVSR,
MSRResNet, RRDBNet, TOFlow, TTSRNet)
MSRResNet, RRDBNet, TDANNet, TOFlow, TTSRNet)

__all__ = [
'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
Expand All @@ -25,5 +25,5 @@
'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder',
'BasicVSRNet', 'IconVSR', 'TTSRNet'
'BasicVSRNet', 'IconVSR', 'TTSRNet', 'TDANNet'
]
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
from .rrdb_net import RRDBNet
from .sr_resnet import MSRResNet
from .srcnn import SRCNN
from .tdan_net import TDANNet
from .tof import TOFlow
from .ttsr_net import TTSRNet

__all__ = [
'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN',
'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet'
'BasicVSRNet', 'IconVSR', 'RDN', 'TTSRNet', 'TDANNet'
]
160 changes: 160 additions & 0 deletions mmedit/models/backbones/sr_backbones/tdan_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init
from mmcv.ops import DeformConv2d, DeformConv2dPack, deform_conv2d
from mmcv.runner import load_checkpoint
from torch.nn.modules.utils import _pair

from mmedit.models.common import (PixelShufflePack, ResidualBlockNoBN,
make_layer)
from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


class AugmentedDeformConv2dPack(DeformConv2d):
"""Augmented Deformable Convolution Pack.

Different from DeformConv2dPack, which generates offsets from the
preceding feature, this AugmentedDeformConv2dPack takes another feature to
generate the offsets.

Args:
in_channels (int): Number of channels in the input feature.
out_channels (int): Number of channels produced by the convolution.
kernel_size (int or tuple[int]): Size of the convolving kernel.
stride (int or tuple[int]): Stride of the convolution. Default: 1.
padding (int or tuple[int]): Zero-padding added to both sides of the
input. Default: 0.
dilation (int or tuple[int]): Spacing between kernel elements.
Default: 1.
groups (int): Number of blocked connections from input channels to
output channels. Default: 1.
deform_groups (int): Number of deformable group partitions.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""

def __init__(self, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about the same as docstring rather than use *args and **kwargs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This class is similar to DeformConv2dPack. Therefore I think it is better to follow the definition in DeformConv2dPack.

Reference: https://mmcv.readthedocs.io/en/latest/_modules/mmcv/ops/deform_conv.html#DeformConv2dPack

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK

super().__init__(*args, **kwargs)

self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deform_groups is not in docstring

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay~

kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
bias=True)

self.init_offset()

def init_offset(self):
    """Fill the offset convolution with constants (weight 0, bias 0)."""
    offset_conv = self.conv_offset
    constant_init(offset_conv, val=0, bias=0)

def forward(self, x, extra_feat):
    """Deformably convolve ``x`` using offsets derived from ``extra_feat``.

    Args:
        x: Feature that is convolved with ``self.weight``.
        extra_feat: Feature passed through ``self.conv_offset`` to
            produce the offsets.

    Returns:
        The result of ``deform_conv2d`` on ``x``.
    """
    # Offsets come from the auxiliary feature rather than from ``x`` itself.
    sampling_offset = self.conv_offset(extra_feat)

    # Convolution hyper-parameters, in the positional order that
    # ``deform_conv2d`` is called with.
    conv_settings = (self.stride, self.padding, self.dilation, self.groups,
                     self.deform_groups)
    return deform_conv2d(x, sampling_offset, self.weight, *conv_settings)


@BACKBONES.register_module()
class TDANNet(nn.Module):
    """TDAN network structure for video super-resolution.

    Support only x4 upsampling.
    Paper:
        TDAN: Temporally-Deformable Alignment Network for Video Super-
        Resolution, CVPR, 2020

    The reconstruction branch takes the channel-wise concatenation of the
    aligned LR frames and its first convolution expects
    ``in_channels * 5`` channels, so the input sequence must contain
    exactly five frames.

    Args:
        in_channels (int): Number of channels of the input image. Default: 3.
        mid_channels (int): Number of channels of the intermediate features.
            Default: 64.
        out_channels (int): Number of channels of the output image. Default: 3.
        num_blocks (Sequence[int]): Number of residual blocks before and
            after temporal alignment. Default: (5, 10).
    """

    def __init__(self,
                 in_channels=3,
                 mid_channels=64,
                 out_channels=3,
                 num_blocks=(5, 10)):

        super().__init__()

        # Per-frame feature extraction.
        self.feat_extract = nn.Sequential(
            ConvModule(in_channels, mid_channels, 3, padding=1),
            make_layer(
                ResidualBlockNoBN, num_blocks[0], mid_channels=mid_channels))

        # Fuses the (center, neighbor) feature pair into the feature that
        # drives the offsets of ``align_1``.
        self.feat_aggregate = nn.Sequential(
            nn.Conv2d(mid_channels * 2, mid_channels, 3, padding=1, bias=True),
            DeformConv2dPack(
                mid_channels, mid_channels, 3, padding=1, deform_groups=8),
            DeformConv2dPack(
                mid_channels, mid_channels, 3, padding=1, deform_groups=8))
        self.align_1 = AugmentedDeformConv2dPack(
            mid_channels, mid_channels, 3, padding=1, deform_groups=8)
        self.align_2 = DeformConv2dPack(
            mid_channels, mid_channels, 3, padding=1, deform_groups=8)
        # Aligned features are projected back to image space. They are
        # concatenated with the raw center frame and later viewed as
        # (n, t, c, h, w), so the projection must emit ``in_channels``
        # channels rather than a fixed 3.
        self.to_rgb = nn.Conv2d(
            mid_channels, in_channels, 3, padding=1, bias=True)

        # Two x2 pixel-shuffle stages give the overall x4 upsampling.
        self.reconstruct = nn.Sequential(
            ConvModule(in_channels * 5, mid_channels, 3, padding=1),
            make_layer(
                ResidualBlockNoBN, num_blocks[1], mid_channels=mid_channels),
            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
            PixelShufflePack(mid_channels, mid_channels, 2, upsample_kernel=3),
            nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False))

    def forward(self, lrs):
        """Forward function for TDANNet.

        Args:
            lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

        Returns:
            tuple[Tensor]: Output HR image with shape (n, c, 4h, 4w) and
                aligned LR images with shape (n, t, c, h, w).
        """
        n, t, c, h, w = lrs.size()
        center_idx = t // 2
        lr_center = lrs[:, center_idx, :, :, :]  # LR center frame

        # extract features
        feats = self.feat_extract(lrs.view(-1, c, h, w)).view(n, t, -1, h, w)

        # alignment of LR frames
        feat_center = feats[:, center_idx, :, :, :].contiguous()
        aligned_lrs = []
        for i in range(t):
            if i == center_idx:
                # The center frame is the reference; it is used as is.
                aligned_lrs.append(lr_center)
            else:
                feat_neig = feats[:, i, :, :, :].contiguous()
                feat_agg = torch.cat([feat_center, feat_neig], dim=1)
                feat_agg = self.feat_aggregate(feat_agg)

                aligned_feat = self.align_2(self.align_1(feat_neig, feat_agg))
                aligned_lrs.append(self.to_rgb(aligned_feat))

        # Stack the frames along the channel axis: (n, t * c, h, w).
        aligned_lrs = torch.cat(aligned_lrs, dim=1)

        # output HR center frame and the aligned LR frames
        return self.reconstruct(aligned_lrs), aligned_lrs.view(n, t, c, h, w)

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether strictly load the pretrained
                model. Defaults to True.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is not None:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
23 changes: 23 additions & 0 deletions tests/test_tdan_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest
import torch

from mmedit.models.backbones.sr_backbones.tdan_net import TDANNet


def test_tdan_net():
    """Test TDANNet."""

    # DCN is available only on GPU, so there is nothing to run without CUDA
    if not torch.cuda.is_available():
        return

    model = TDANNet().cuda()
    lr_frames = torch.rand(1, 5, 3, 64, 64).cuda()
    model.init_weights(pretrained=None)

    results = model(lr_frames)
    # (1) HR center + (2) aligned LRs
    assert len(results) == 2
    # HR center frame
    assert results[0].shape == (1, 3, 256, 256)
    # aligned LRs
    assert results[1].shape == (1, 5, 3, 64, 64)

    # pretrained should be str or None
    with pytest.raises(TypeError):
        model.init_weights(pretrained=[1])