
Add TDAN backbone #316

Merged
merged 8 commits into from
May 28, 2021

Conversation

ckkelvinchan
Member

TDAN: Temporally-Deformable Alignment Network for Video Super-Resolution, CVPR, 2020

@innerlee innerlee requested a review from Yshuo-Li May 18, 2021 07:22
@ckkelvinchan ckkelvinchan changed the title Add TDAN architecture Add TDAN backbone May 18, 2021
@codecov

codecov bot commented May 19, 2021

Codecov Report

Merging #316 (c3dd96a) into master (b72baab) will decrease coverage by 0.25%.
The diff coverage is 36.53%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #316      +/-   ##
==========================================
- Coverage   80.30%   80.05%   -0.25%     
==========================================
  Files         177      178       +1     
  Lines        9280     9332      +52     
  Branches     1352     1357       +5     
==========================================
+ Hits         7452     7471      +19     
- Misses       1640     1673      +33     
  Partials      188      188              
Flag Coverage Δ
unittests 80.03% <36.53%> (-0.25%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
mmedit/models/backbones/__init__.py 100.00% <ø> (ø)
mmedit/models/backbones/sr_backbones/tdan_net.py 35.29% <35.29%> (ø)
mmedit/models/backbones/sr_backbones/__init__.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update b72baab...c3dd96a.

generate the offsets.

Args:
in_channels (int): Same as nn.Conv2d.
Collaborator

Please refactor the docstring of these args, or copy their explanations from nn.Conv2d.

Member Author

Okay~
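A hedged sketch of what the refactored docstring could look like, reusing nn.Conv2d's wording as the reviewer suggests. The class name and argument set are taken from the surrounding diff; the exact merged text may differ:

```python
class AugmentedDeformConv2dPack:
    """Augmented deformable conv that also generates the offsets.

    The argument descriptions below reuse the wording of ``nn.Conv2d``,
    per the review suggestion; this is an illustrative sketch, not the
    merged docstring.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple[int]): Size of the convolving kernel.
        stride (int or tuple[int]): Stride of the convolution. Default: 1.
        padding (int or tuple[int]): Zero-padding added to both sides of
            the input. Default: 0.
        deform_groups (int): Number of deformable group partitions.
            Default: 1.
    """
```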

False.
"""

def __init__(self, *args, **kwargs):
Collaborator

How about listing the arguments explicitly, as in the docstring, rather than using *args and **kwargs?

Member Author

This class is similar to DeformConv2dPack, so I think it is better to follow the definition used there.

Reference: https://mmcv.readthedocs.io/en/latest/_modules/mmcv/ops/deform_conv.html#DeformConv2dPack

Collaborator

OK
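The forwarding pattern being discussed can be sketched without mmcv. `DeformConvBase` below is a hypothetical stand-in for mmcv's DeformConv2d (the real class takes the full nn.Conv2d-style argument list), and the offset channel count mirrors the `deform_groups * 2 * kernel_size[0] * kernel_size[1]` expression from the diff:

```python
class DeformConvBase:
    """Hypothetical stand-in for mmcv's DeformConv2d, which accepts the
    full nn.Conv2d-style argument list and stores it on self."""

    def __init__(self, in_channels, out_channels, kernel_size,
                 deform_groups=1):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = (kernel_size, kernel_size)
        self.deform_groups = deform_groups


class AugmentedDeformConv2dPack(DeformConvBase):
    """Follows the DeformConv2dPack pattern: forward everything to the
    parent via *args/**kwargs, then build the offset conv from the
    attributes the parent stored."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # In the real module this would be an nn.Conv2d producing
        # 2 offsets (x and y) per kernel position per deform group.
        self.offset_out_channels = (self.deform_groups * 2 *
                                    self.kernel_size[0] *
                                    self.kernel_size[1])


conv = AugmentedDeformConv2dPack(64, 64, 3, deform_groups=8)
# 8 groups * 2 * 3 * 3 = 144 offset channels
```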


self.conv_offset = nn.Conv2d(
self.in_channels,
self.deform_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
Collaborator

deform_groups is not in the docstring.

Member Author

Okay~

super().__init__()

self.feat_extract = nn.Sequential(
ConvModule(3, 64, 3, padding=1),
Collaborator

How about declaring args such as in_channels, out_channels, and mid_channels=64?


self.reconstruct = nn.Sequential(
ConvModule(3 * 5, 64, 3, padding=1),
make_layer(ResidualBlockNoBN, 10, mid_channels=64),
Collaborator

And num_blocks

Collaborator

such as num_blocks=10
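Putting the two suggestions together, the constructor could expose the hard-coded 3, 64, and 10 as arguments. A minimal stub, with names and defaults that are illustrative only and the torch/mmcv calls shown as comments since the merged signature may differ:

```python
class TDANNet:
    """Sketch of a parameterized TDANNet constructor per the review."""

    def __init__(self, in_channels=3, mid_channels=64, out_channels=3,
                 num_blocks=10):
        # feat_extract would start with
        #   ConvModule(in_channels, mid_channels, 3, padding=1)
        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.out_channels = out_channels
        # reconstruct would use
        #   make_layer(ResidualBlockNoBN, num_blocks,
        #              mid_channels=mid_channels)
        self.num_blocks = num_blocks


net = TDANNet()
```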

lrs (Tensor): Input LR sequence with shape (n, t, c, h, w).

Returns:
Tensor: Output HR image with shape (n, c, 4h, 4w).
Collaborator

The returns in the docstring are different from those in the code.
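For reference, the shape bookkeeping the docstring describes: an LR sequence (n, t, c, h, w) yields a center-frame HR output (n, c, 4h, 4w) at 4x scale. (The code may additionally return the aligned LR frames, which is presumably the mismatch the reviewer means.) A small helper making the documented shape explicit:

```python
def tdan_output_shape(lr_shape, scale=4):
    """Expected HR output shape for an LR input of shape (n, t, c, h, w).

    TDAN reconstructs only the center frame, so the temporal dim t is
    dropped and the spatial dims are upscaled by `scale`.
    """
    n, t, c, h, w = lr_shape
    return (n, c, scale * h, scale * w)


tdan_output_shape((2, 5, 3, 16, 16))  # -> (2, 3, 64, 64)
```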

@ckkelvinchan ckkelvinchan added the status/WIP work in progress normally label May 28, 2021
@ckkelvinchan
Member Author

Please do not merge for the moment.

@ckkelvinchan
Member Author

> Please do not merge for the moment.

Okay now.

@ckkelvinchan ckkelvinchan removed the status/WIP work in progress normally label May 28, 2021
@innerlee innerlee merged commit 92cebc1 into open-mmlab:master May 28, 2021
@ckkelvinchan ckkelvinchan deleted the tdan_arch branch June 1, 2021 03:47
Yshuo-Li pushed a commit to Yshuo-Li/mmediting that referenced this pull request Jul 15, 2022
* Add TDAN architecture

* Modify docstring

* Fix bug in unittest

* Update backbone

* Minor update

* Change TDANNet arguments