Add RDN. #233

Merged (9 commits) on Apr 25, 2021

Changes from 2 commits
4 changes: 2 additions & 2 deletions mmedit/models/backbones/__init__.py
@@ -12,15 +12,15 @@
SimpleEncoderDecoder)
# yapf: enable
from .generation_backbones import ResnetGenerator, UnetGenerator
-from .sr_backbones import EDSR, SRCNN, EDVRNet, MSRResNet, RRDBNet, TOFlow
+from .sr_backbones import EDSR, RDN, SRCNN, EDVRNet, MSRResNet, RRDBNet, TOFlow

__all__ = [
    'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder',
    'GLEncoderDecoder', 'GLEncoder', 'GLDecoder', 'GLDilationNeck',
    'PConvEncoderDecoder', 'PConvEncoder', 'PConvDecoder', 'ResNetEnc',
    'ResNetDec', 'ResShortcutEnc', 'ResShortcutDec', 'RRDBNet',
    'DeepFillEncoder', 'HolisticIndexBlock', 'DepthwiseIndexBlock',
-    'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR',
+    'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', 'RDN',
    'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder',
    'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN',
    'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder'
3 changes: 2 additions & 1 deletion mmedit/models/backbones/sr_backbones/__init__.py
@@ -1,8 +1,9 @@
from .edsr import EDSR
from .edvr_net import EDVRNet
from .rdn import RDN
from .rrdb_net import RRDBNet
from .sr_resnet import MSRResNet
from .srcnn import SRCNN
from .tof import TOFlow

-__all__ = ['MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN']
+__all__ = ['MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', 'RDN']
187 changes: 187 additions & 0 deletions mmedit/models/backbones/sr_backbones/rdn.py
@@ -0,0 +1,187 @@
import torch
from mmcv.runner import load_checkpoint
from torch import nn

from mmedit.models.registry import BACKBONES
from mmedit.utils import get_root_logger


class DenseLayer(nn.Module):
    """Dense layer.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
    """

    def __init__(self, in_channels, out_channels):
        super(DenseLayer, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, padding=3 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results, tensor with shape (n, c, h, w).
        """
        return torch.cat([x, self.relu(self.conv(x))], 1)

Contributor: I don't think the shape stays the same after concatenating.
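As the comment notes, the concatenation grows the channel dimension. A minimal shape check (a sketch, assuming the DenseLayer import path introduced by this PR):

import torch

from mmedit.models.backbones.sr_backbones.rdn import DenseLayer

layer = DenseLayer(in_channels=64, out_channels=32)
x = torch.rand(1, 64, 8, 8)
out = layer(x)
# torch.cat along dim 1 appends the conv output to the input,
# so the channels grow from 64 to 64 + 32.
assert out.shape == (1, 96, 8, 8)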


class RDB(nn.Module):
    """Residual Dense Block of Residual Dense Network.

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
    """

Contributor: docstring does not match __init__.

    def __init__(self, in_channels, growth_rate, num_layers):
        super(RDB, self).__init__()
        self.layers = nn.Sequential(*[
            DenseLayer(in_channels + growth_rate * i, growth_rate)
            for i in range(num_layers)
        ])

Contributor: Is growth_rate an int or a float ratio?
Collaborator (author): It is an int, the channel number of each dense layer's output.

        # local feature fusion
        self.lff = nn.Conv2d(
            in_channels + growth_rate * num_layers, growth_rate, kernel_size=1)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results, tensor with shape (n, c, h, w).
        """
        return x + self.lff(self.layers(x))  # local residual learning
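To make the growth_rate discussion concrete, a small sketch (plain Python using the PR's default numbers, not part of the diff) of how channel counts accumulate inside one RDB before local feature fusion:

in_channels, growth_rate, num_layers = 64, 64, 8

# Channels seen by each DenseLayer: the block input plus growth_rate
# channels appended by every preceding layer.
widths = [in_channels + growth_rate * i for i in range(num_layers)]
print(widths)  # [64, 128, 192, 256, 320, 384, 448, 512]

# The 1x1 LFF conv fuses the final concatenation back down.
print(in_channels + growth_rate * num_layers)  # 576 -> growth_rate (64)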


@BACKBONES.register_module()
class RDN(nn.Module):
    """RDN model for single image super-resolution.

    Paper: Residual Dense Network for Image Super-Resolution

    Adapted from:
    https://github.com/yulunzhang/RDN.git
    https://github.com/yjn870/RDN-pytorch

    Args:
        in_channels (int): Channel number of inputs.
        out_channels (int): Channel number of outputs.
        mid_channels (int): Channel number of intermediate features.
            Default: 64.
        num_blocks (int): Block number in the trunk network. Default: 16.
        upscale_factor (int): Upsampling factor. Supports 2^n and 3.
            Default: 4.
        num_layers (int): Layer number in the residual dense block.
            Default: 8.
        growth_rate (int): Channel growth in each layer of an RDB.
            Default: 64.
    """

Contributor: Add license info.
Collaborator (author): These GitHub projects don't have license files.
Contributor: Okay, we need to contact them.

Contributor: "rate" is a term often associated with percentages, so it is not a good name for this number. How about growth, channel_growth, or growth_num?
Collaborator (author): Picked channel_growth.

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels=64,
                 num_blocks=16,
                 upscale_factor=4,
                 num_layers=8,
                 growth_rate=64):

        super(RDN, self).__init__()
        self.G0 = mid_channels
        self.G = growth_rate
        self.D = num_blocks
        self.C = num_layers
Collaborator: Such names (G0, G, or D) are ugly and confusing. Please rename them with regular terms.
Collaborator: In general, variable names should not contain capital letters.
Contributor: Yeah, just use self.mid_channels = mid_channels.
Collaborator (author): G, D, C, etc. are the variable names used in the paper; I'll change them in the code.

        # shallow feature extraction
        self.sfe1 = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, padding=3 // 2)
        self.sfe2 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, padding=3 // 2)

        # residual dense blocks
        self.rdbs = nn.ModuleList([RDB(self.G0, self.G, self.C)])
        for _ in range(self.D - 1):
            self.rdbs.append(RDB(self.G, self.G, self.C))

        # global feature fusion
        self.gff = nn.Sequential(
            nn.Conv2d(self.G * self.D, self.G0, kernel_size=1),
            nn.Conv2d(self.G0, self.G0, kernel_size=3, padding=3 // 2))

        # up-sampling
        assert 2 <= upscale_factor <= 4
        if upscale_factor == 2 or upscale_factor == 4:
            self.upscale = []
            for _ in range(upscale_factor // 2):
                self.upscale.extend([
                    nn.Conv2d(
                        self.G0,
                        self.G0 * (2**2),
                        kernel_size=3,
                        padding=3 // 2),
                    nn.PixelShuffle(2)
                ])
            self.upscale = nn.Sequential(*self.upscale)
        else:
            self.upscale = nn.Sequential(
                nn.Conv2d(
                    self.G0,
                    self.G0 * (upscale_factor**2),
                    kernel_size=3,
                    padding=3 // 2), nn.PixelShuffle(upscale_factor))

        self.output = nn.Conv2d(
            self.G0, out_channels, kernel_size=3, padding=3 // 2)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results, tensor with shape
                (n, out_channels, upscale_factor * h, upscale_factor * w).
        """
        sfe1 = self.sfe1(x)
        sfe2 = self.sfe2(sfe1)

        x = sfe2
        local_features = []
        for i in range(self.D):
            x = self.rdbs[i](x)
            local_features.append(x)

        # global feature fusion and global residual learning
        x = self.gff(torch.cat(local_features, 1)) + sfe1
        x = self.upscale(x)
        x = self.output(x)
        return x

    def init_weights(self, pretrained=None, strict=True):
        """Init weights for models.

        Args:
            pretrained (str, optional): Path for pretrained weights. If given
                None, pretrained weights will not be loaded. Defaults to None.
            strict (bool, optional): Whether to strictly load the pretrained
                model. Defaults to True.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=strict, logger=logger)
        elif pretrained is None:
            pass  # use default initialization
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
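The up-sampling branch relies on nn.PixelShuffle, which rearranges a (n, c * r^2, h, w) tensor into (n, c, h * r, w * r); a x4 factor is therefore two stacked x2 stages, while x3 uses a single PixelShuffle(3). A standalone sketch of the same pattern (illustrative, not the PR's code):

import torch
from torch import nn

# One x2 stage: a conv producing 2**2 times the channels, then PixelShuffle(2).
up_x4 = nn.Sequential(
    nn.Conv2d(64, 64 * 4, kernel_size=3, padding=1), nn.PixelShuffle(2),
    nn.Conv2d(64, 64 * 4, kernel_size=3, padding=1), nn.PixelShuffle(2))

x = torch.rand(1, 64, 8, 8)
assert up_x4(x).shape == (1, 64, 32, 32)  # spatial size x4, channels preserved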
56 changes: 56 additions & 0 deletions tests/test_rdn.py
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn

from mmedit.models import build_backbone


def test_rdn():

    scale = 4

    model_cfg = dict(
        type='RDN',
        in_channels=3,
        out_channels=3,
        mid_channels=64,
        num_blocks=16,
        upscale_factor=scale)

    # build model
    model = build_backbone(model_cfg)

    # test attributes
    assert model.__class__.__name__ == 'RDN'

    # prepare data
    inputs = torch.rand(1, 3, 32, 16)
    targets = torch.rand(1, 3, 128, 64)

    # prepare loss
    loss_function = nn.L1Loss()

    # prepare optimizer
    optimizer = torch.optim.Adam(model.parameters())

    # test on cpu
    output = model(inputs)
    optimizer.zero_grad()
    loss = loss_function(output, targets)
    loss.backward()
    optimizer.step()
    assert torch.is_tensor(output)
    assert output.shape == targets.shape

    # test on gpu
    if torch.cuda.is_available():
        model = model.cuda()
        optimizer = torch.optim.Adam(model.parameters())
        inputs = inputs.cuda()
        targets = targets.cuda()
        output = model(inputs)
        optimizer.zero_grad()
        loss = loss_function(output, targets)
        loss.backward()
        optimizer.step()
        assert torch.is_tensor(output)
        assert output.shape == targets.shape
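For reference, a minimal downstream usage sketch mirroring this test: build the backbone from a config dict through the registry, then initialize weights (pass a checkpoint path to init_weights to load pretrained weights; None keeps the default initialization):

import torch

from mmedit.models import build_backbone

model = build_backbone(
    dict(
        type='RDN',
        in_channels=3,
        out_channels=3,
        mid_channels=64,
        num_blocks=16,
        upscale_factor=2))
model.init_weights(pretrained=None)

sr = model(torch.rand(1, 3, 24, 24))
assert sr.shape == (1, 3, 48, 48)  # x2 super-resolved output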