From 676a0c503bca4a78969f55903b6ded60b8fe57c4 Mon Sep 17 00:00:00 2001 From: liyinshuo Date: Thu, 1 Apr 2021 23:19:47 +0800 Subject: [PATCH 1/7] add RDN --- mmedit/models/backbones/__init__.py | 4 +- .../models/backbones/sr_backbones/__init__.py | 3 +- mmedit/models/backbones/sr_backbones/rdn.py | 185 ++++++++++++++++++ 3 files changed, 189 insertions(+), 3 deletions(-) create mode 100644 mmedit/models/backbones/sr_backbones/rdn.py diff --git a/mmedit/models/backbones/__init__.py b/mmedit/models/backbones/__init__.py index 93393a1863..6705569a6b 100644 --- a/mmedit/models/backbones/__init__.py +++ b/mmedit/models/backbones/__init__.py @@ -12,7 +12,7 @@ SimpleEncoderDecoder) # yapf: enable from .generation_backbones import ResnetGenerator, UnetGenerator -from .sr_backbones import EDSR, SRCNN, EDVRNet, MSRResNet, RRDBNet, TOFlow +from .sr_backbones import EDSR, RDN, SRCNN, EDVRNet, MSRResNet, RRDBNet, TOFlow __all__ = [ 'MSRResNet', 'VGG16', 'PlainDecoder', 'SimpleEncoderDecoder', @@ -20,7 +20,7 @@ 'PConvEncoderDecoder', 'PConvEncoder', 'PConvDecoder', 'ResNetEnc', 'ResNetDec', 'ResShortcutEnc', 'ResShortcutDec', 'RRDBNet', 'DeepFillEncoder', 'HolisticIndexBlock', 'DepthwiseIndexBlock', - 'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', + 'ContextualAttentionNeck', 'DeepFillDecoder', 'EDSR', 'RDN', 'DeepFillEncoderDecoder', 'EDVRNet', 'IndexedUpsample', 'IndexNetEncoder', 'IndexNetDecoder', 'TOFlow', 'ResGCAEncoder', 'ResGCADecoder', 'SRCNN', 'UnetGenerator', 'ResnetGenerator', 'FBAResnetDilated', 'FBADecoder' diff --git a/mmedit/models/backbones/sr_backbones/__init__.py b/mmedit/models/backbones/sr_backbones/__init__.py index 24f181f583..477f3c1ea7 100644 --- a/mmedit/models/backbones/sr_backbones/__init__.py +++ b/mmedit/models/backbones/sr_backbones/__init__.py @@ -1,8 +1,9 @@ from .edsr import EDSR from .edvr_net import EDVRNet +from .rdn import RDN from .rrdb_net import RRDBNet from .sr_resnet import MSRResNet from .srcnn import SRCNN from .tof import TOFlow -__all__ = ['MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN'] +__all__ = ['MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', 'RDN'] diff --git a/mmedit/models/backbones/sr_backbones/rdn.py b/mmedit/models/backbones/sr_backbones/rdn.py new file mode 100644 index 0000000000..cc9fba2d2d --- /dev/null +++ b/mmedit/models/backbones/sr_backbones/rdn.py @@ -0,0 +1,185 @@ +import torch +from mmcv.runner import load_checkpoint +from torch import nn + +from mmedit.models.registry import BACKBONES +from mmedit.utils import get_root_logger + + +class DenseLayer(nn.Module): + + def __init__(self, in_channels, out_channels): + """Dense layer + + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + + """ + super(DenseLayer, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, padding=3 // 2) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results, tensor with shape (n, c, h, w). + """ + return torch.cat([x, self.relu(self.conv(x))], 1) + + +class RDB(nn.Module): + + def __init__(self, in_channels, growth_rate, num_layers): + """Residual Dense Block of Residual Dense Network + + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + """ + super(RDB, self).__init__() + self.layers = nn.Sequential(*[ + DenseLayer(in_channels + growth_rate * i, growth_rate) + for i in range(num_layers) + ]) + + # local feature fusion + self.lff = nn.Conv2d( + in_channels + growth_rate * num_layers, growth_rate, kernel_size=1) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results, tensor with shape (n, c, h, w). + """ + return x + self.lff(self.layers(x)) # local residual learning + + +@BACKBONES.register_module() +class RDN(nn.Module): + + def __init__(self, + in_channels, + out_channels, + mid_channels=64, + num_blocks=16, + upscale_factor=4, + num_layers=8, + growth_rate=64): + """RDN model for single image super-resolution. + + Paper: Residual Dense Network for Image Super-Resolution + Ref repo: https://github.com/yulunzhang/RDN.git + + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + mid_channels (int): Channel number of intermediate features. + Default: 64. + num_blocks (int): Block number in the trunk network. Default: 16. + upscale_factor (int): Upsampling factor. Support 2^n and 3. + Default: 4. + num_layer (int): Layer number in the Residual Dense Block. + Default: 8. + growth_rate(int): Channels growth in each layer of RDB. + Default: 64. + """ + + super(RDN, self).__init__() + self.G0 = mid_channels + self.G = growth_rate + self.D = num_blocks + self.C = num_layers + + # shallow feature extraction + self.sfe1 = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, padding=3 // 2) + self.sfe2 = nn.Conv2d( + mid_channels, mid_channels, kernel_size=3, padding=3 // 2) + + # residual dense blocks + self.rdbs = nn.ModuleList([RDB(self.G0, self.G, self.C)]) + for _ in range(self.D - 1): + self.rdbs.append(RDB(self.G, self.G, self.C)) + + # global feature fusion + self.gff = nn.Sequential( + nn.Conv2d(self.G * self.D, self.G0, kernel_size=1), + nn.Conv2d(self.G0, self.G0, kernel_size=3, padding=3 // 2)) + + # up-sampling + assert 2 <= upscale_factor <= 4 + if upscale_factor == 2 or upscale_factor == 4: + self.upscale = [] + for _ in range(upscale_factor // 2): + self.upscale.extend([ + nn.Conv2d( + self.G0, + self.G0 * (2**2), + kernel_size=3, + padding=3 // 2), + nn.PixelShuffle(2) + ]) + self.upscale = nn.Sequential(*self.upscale) + else: + self.upscale = nn.Sequential( + nn.Conv2d( + self.G0, + self.G0 * (upscale_factor**2), + kernel_size=3, + padding=3 // 2), nn.PixelShuffle(upscale_factor)) + + self.output = nn.Conv2d( + self.G0, out_channels, kernel_size=3, padding=3 // 2) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results, tensor with shape (n, c, h, w). + """ + + sfe1 = self.sfe1(x) + sfe2 = self.sfe2(sfe1) + + x = sfe2 + local_features = [] + for i in range(self.D): + x = self.rdbs[i](x) + local_features.append(x) + + x = self.gff(torch.cat(local_features, 1)) + sfe1 + # global residual learning + x = self.upscale(x) + x = self.output(x) + return x + + def init_weights(self, pretrained=None, strict=True): + """Init weights for models. + + Args: + pretrained (str, optional): Path for pretrained weights. If given + None, pretrained weights will not be loaded. Defaults to None. + strict (boo, optional): Whether strictly load the pretrained model. + Defaults to True. + """ + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=strict, logger=logger) + elif pretrained is None: + pass # use default initialization + else: + raise TypeError('"pretrained" must be a str or None. ' + f'But received {type(pretrained)}.') From 8fa9072b8edc2bea53e95672f071b32b4ac74667 Mon Sep 17 00:00:00 2001 From: liyinshuo Date: Fri, 2 Apr 2021 00:58:57 +0800 Subject: [PATCH 2/7] Add docstring and test. --- mmedit/models/backbones/sr_backbones/rdn.py | 62 +++++++++++---------- tests/test_rdn.py | 56 +++++++++++++++++++ 2 files changed, 88 insertions(+), 30 deletions(-) create mode 100644 tests/test_rdn.py diff --git a/mmedit/models/backbones/sr_backbones/rdn.py b/mmedit/models/backbones/sr_backbones/rdn.py index cc9fba2d2d..9e2495cdfa 100644 --- a/mmedit/models/backbones/sr_backbones/rdn.py +++ b/mmedit/models/backbones/sr_backbones/rdn.py @@ -7,15 +7,15 @@ class DenseLayer(nn.Module): + """Dense layer - def __init__(self, in_channels, out_channels): - """Dense layer + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. - Args: - in_channels (int): Channel number of inputs. - out_channels (int): Channel number of outputs. + """ - """ + def __init__(self, in_channels, out_channels): super(DenseLayer, self).__init__() self.conv = nn.Conv2d( in_channels, out_channels, kernel_size=3, padding=3 // 2) @@ -34,14 +34,14 @@ def forward(self, x): class RDB(nn.Module): + """Residual Dense Block of Residual Dense Network - def __init__(self, in_channels, growth_rate, num_layers): - """Residual Dense Block of Residual Dense Network + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + """ - Args: - in_channels (int): Channel number of inputs. - out_channels (int): Channel number of outputs. - """ + def __init__(self, in_channels, growth_rate, num_layers): super(RDB, self).__init__() self.layers = nn.Sequential(*[ DenseLayer(in_channels + growth_rate * i, growth_rate) @@ -66,6 +66,26 @@ def forward(self, x): @BACKBONES.register_module() class RDN(nn.Module): + """RDN model for single image super-resolution. + + Paper: Residual Dense Network for Image Super-Resolution + Adapted from: + https://github.com/yulunzhang/RDN.git + https://github.com/yjn870/RDN-pytorch + + Args: + in_channels (int): Channel number of inputs. + out_channels (int): Channel number of outputs. + mid_channels (int): Channel number of intermediate features. + Default: 64. + num_blocks (int): Block number in the trunk network. Default: 16. + upscale_factor (int): Upsampling factor. Support 2^n and 3. + Default: 4. + num_layer (int): Layer number in the Residual Dense Block. + Default: 8. + growth_rate(int): Channels growth in each layer of RDB. + Default: 64. + """ def __init__(self, in_channels, @@ -75,24 +95,6 @@ def __init__(self, upscale_factor=4, num_layers=8, growth_rate=64): - """RDN model for single image super-resolution. - - Paper: Residual Dense Network for Image Super-Resolution - Ref repo: https://github.com/yulunzhang/RDN.git - - Args: - in_channels (int): Channel number of inputs. - out_channels (int): Channel number of outputs. - mid_channels (int): Channel number of intermediate features. - Default: 64. - num_blocks (int): Block number in the trunk network. Default: 16. - upscale_factor (int): Upsampling factor. Support 2^n and 3. - Default: 4. - num_layer (int): Layer number in the Residual Dense Block. - Default: 8. - growth_rate(int): Channels growth in each layer of RDB. - Default: 64. - """ super(RDN, self).__init__() self.G0 = mid_channels diff --git a/tests/test_rdn.py b/tests/test_rdn.py new file mode 100644 index 0000000000..ccf07fa9c0 --- /dev/null +++ b/tests/test_rdn.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn + +from mmedit.models import build_backbone + + +def test_rdn(): + + scale = 4 + + model_cfg = dict( + type='RDN', + in_channels=3, + out_channels=3, + mid_channels=64, + num_blocks=16, + upscale_factor=scale) + + # build model + model = build_backbone(model_cfg) + + # test attributes + assert model.__class__.__name__ == 'RDN' + + # prepare data + inputs = torch.rand(1, 3, 32, 16) + targets = torch.rand(1, 3, 128, 64) + + # prepare loss + loss_function = nn.L1Loss() + + # prepare optimizer + optimizer = torch.optim.Adam(model.parameters()) + + # test on cpu + output = model(inputs) + optimizer.zero_grad() + loss = loss_function(output, targets) + loss.backward() + optimizer.step() + assert torch.is_tensor(output) + assert output.shape == targets.shape + + # test on gpu + if torch.cuda.is_available(): + model = model.cuda() + optimizer = torch.optim.Adam(model.parameters()) + inputs = inputs.cuda() + targets = targets.cuda() + output = model(inputs) + optimizer.zero_grad() + loss = loss_function(output, targets) + loss.backward() + optimizer.step() + assert torch.is_tensor(output) + assert output.shape == targets.shape From 557abb55bcf883a5c8a99f990369956171b8c388 Mon Sep 17 00:00:00 2001 From: liyinshuo Date: Tue, 13 Apr 2021 11:02:29 +0800 Subject: [PATCH 3/7] Tiny fix. --- mmedit/models/backbones/sr_backbones/rdn.py | 68 ++++++++++++--------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/mmedit/models/backbones/sr_backbones/rdn.py b/mmedit/models/backbones/sr_backbones/rdn.py index 9e2495cdfa..c3cd487815 100644 --- a/mmedit/models/backbones/sr_backbones/rdn.py +++ b/mmedit/models/backbones/sr_backbones/rdn.py @@ -16,7 +16,7 @@ class DenseLayer(nn.Module): """ def __init__(self, in_channels, out_channels): - super(DenseLayer, self).__init__() + super().__init__() self.conv = nn.Conv2d( in_channels, out_channels, kernel_size=3, padding=3 // 2) self.relu = nn.ReLU(inplace=True) @@ -25,10 +25,10 @@ def forward(self, x): """Forward function. Args: - x (Tensor): Input tensor with shape (n, c, h, w). + x (Tensor): Input tensor with shape (n, c_in, h, w). Returns: - Tensor: Forward results, tensor with shape (n, c, h, w). + Tensor: Forward results, tensor with shape (n, c_in+c_out, h, w). """ return torch.cat([x, self.relu(self.conv(x))], 1) @@ -38,19 +38,22 @@ class RDB(nn.Module): Args: in_channels (int): Channel number of inputs. - out_channels (int): Channel number of outputs. + channel_growth (int): Channels growth in each layer. + num_layers (int): Layer number in the Residual Dense Block. """ - def __init__(self, in_channels, growth_rate, num_layers): - super(RDB, self).__init__() + def __init__(self, in_channels, channel_growth, num_layers): + super().__init__() self.layers = nn.Sequential(*[ - DenseLayer(in_channels + growth_rate * i, growth_rate) + DenseLayer(in_channels + channel_growth * i, channel_growth) for i in range(num_layers) ]) # local feature fusion self.lff = nn.Conv2d( - in_channels + growth_rate * num_layers, growth_rate, kernel_size=1) + in_channels + channel_growth * num_layers, + channel_growth, + kernel_size=1) def forward(self, x): """Forward function. @@ -59,7 +62,7 @@ def forward(self, x): x (Tensor): Input tensor with shape (n, c, h, w). Returns: - Tensor: Forward results, tensor with shape (n, c, h, w). + Tensor: Forward results. """ return x + self.lff(self.layers(x)) # local residual learning @@ -83,7 +86,7 @@ class RDN(nn.Module): Default: 4. num_layer (int): Layer number in the Residual Dense Block. Default: 8. - growth_rate(int): Channels growth in each layer of RDB. + channel_growth(int): Channels growth in each layer of RDB. Default: 64. """ @@ -94,13 +97,13 @@ def __init__(self, num_blocks=16, upscale_factor=4, num_layers=8, - growth_rate=64): + channel_growth=64): - super(RDN, self).__init__() - self.G0 = mid_channels - self.G = growth_rate - self.D = num_blocks - self.C = num_layers + super().__init__() + self.mid_channels = mid_channels + self.channel_growth = channel_growth + self.num_blocks = num_blocks + self.num_layers = num_layers # shallow feature extraction self.sfe1 = nn.Conv2d( @@ -109,14 +112,23 @@ def __init__(self, mid_channels, mid_channels, kernel_size=3, padding=3 // 2) # residual dense blocks - self.rdbs = nn.ModuleList([RDB(self.G0, self.G, self.C)]) - for _ in range(self.D - 1): - self.rdbs.append(RDB(self.G, self.G, self.C)) + self.rdbs = nn.ModuleList( + [RDB(self.mid_channels, self.channel_growth, self.num_layers)]) + for _ in range(self.num_blocks - 1): + self.rdbs.append( + RDB(self.channel_growth, self.channel_growth, self.num_layers)) # global feature fusion self.gff = nn.Sequential( - nn.Conv2d(self.G * self.D, self.G0, kernel_size=1), - nn.Conv2d(self.G0, self.G0, kernel_size=3, padding=3 // 2)) + nn.Conv2d( + self.channel_growth * self.num_blocks, + self.mid_channels, + kernel_size=1), + nn.Conv2d( + self.mid_channels, + self.mid_channels, + kernel_size=3, + padding=3 // 2)) # up-sampling assert 2 <= upscale_factor <= 4 @@ -125,8 +137,8 @@ def __init__(self, for _ in range(upscale_factor // 2): self.upscale.extend([ nn.Conv2d( - self.G0, - self.G0 * (2**2), + self.mid_channels, + self.mid_channels * (2**2), kernel_size=3, padding=3 // 2), nn.PixelShuffle(2) @@ -135,13 +147,13 @@ def __init__(self, else: self.upscale = nn.Sequential( nn.Conv2d( - self.G0, - self.G0 * (upscale_factor**2), + self.mid_channels, + self.mid_channels * (upscale_factor**2), kernel_size=3, padding=3 // 2), nn.PixelShuffle(upscale_factor)) self.output = nn.Conv2d( - self.G0, out_channels, kernel_size=3, padding=3 // 2) + self.mid_channels, out_channels, kernel_size=3, padding=3 // 2) def forward(self, x): """Forward function. @@ -150,7 +162,7 @@ def forward(self, x): x (Tensor): Input tensor with shape (n, c, h, w). Returns: - Tensor: Forward results, tensor with shape (n, c, h, w). + Tensor: Forward results. """ sfe1 = self.sfe1(x) @@ -158,7 +170,7 @@ def forward(self, x): x = sfe2 local_features = [] - for i in range(self.D): + for i in range(self.num_blocks): x = self.rdbs[i](x) local_features.append(x) From e2a289275cf122ed89e846e162fb42b7d3567cae Mon Sep 17 00:00:00 2001 From: liyinshuo Date: Tue, 13 Apr 2021 19:18:30 +0800 Subject: [PATCH 4/7] Tiny fix. --- mmedit/models/backbones/sr_backbones/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmedit/models/backbones/sr_backbones/__init__.py b/mmedit/models/backbones/sr_backbones/__init__.py index 315cdfb05f..1ed88be173 100644 --- a/mmedit/models/backbones/sr_backbones/__init__.py +++ b/mmedit/models/backbones/sr_backbones/__init__.py @@ -8,5 +8,6 @@ from .tof import TOFlow __all__ = [ - 'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', 'BasicVSRNet', 'RDN' + 'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', + 'BasicVSRNet', 'RDN' ] From 9a1a7e1602e7feb05b2c106fa0b2e46cc13cb21e Mon Sep 17 00:00:00 2001 From: liyinshuo Date: Thu, 22 Apr 2021 11:30:37 +0800 Subject: [PATCH 5/7] Add license. --- mmedit/models/backbones/sr_backbones/rdn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mmedit/models/backbones/sr_backbones/rdn.py b/mmedit/models/backbones/sr_backbones/rdn.py index c3cd487815..1557d6c286 100644 --- a/mmedit/models/backbones/sr_backbones/rdn.py +++ b/mmedit/models/backbones/sr_backbones/rdn.py @@ -72,9 +72,10 @@ class RDN(nn.Module): """RDN model for single image super-resolution. Paper: Residual Dense Network for Image Super-Resolution - Adapted from: - https://github.com/yulunzhang/RDN.git - https://github.com/yjn870/RDN-pytorch + + Adapted from 'git@github.com:yjn870/RDN-pytorch.git' + 'RDN-pytorch/blob/master/models.py' + Copyright (c) 2021, JaeYun Yeo, under MIT License. Args: in_channels (int): Channel number of inputs. From 2c9bdd17d7f894196c4ff29fb4ca616d2342c2f2 Mon Sep 17 00:00:00 2001 From: liyinshuo Date: Thu, 22 Apr 2021 14:22:51 +0800 Subject: [PATCH 6/7] Tiny Fix --- mmedit/models/backbones/sr_backbones/rdn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmedit/models/backbones/sr_backbones/rdn.py b/mmedit/models/backbones/sr_backbones/rdn.py index 1557d6c286..b0d643fe2b 100644 --- a/mmedit/models/backbones/sr_backbones/rdn.py +++ b/mmedit/models/backbones/sr_backbones/rdn.py @@ -73,7 +73,7 @@ class RDN(nn.Module): Paper: Residual Dense Network for Image Super-Resolution - Adapted from 'git@github.com:yjn870/RDN-pytorch.git' + Adapted from 'https://github.com/yjn870/RDN-pytorch.git' 'RDN-pytorch/blob/master/models.py' Copyright (c) 2021, JaeYun Yeo, under MIT License. From a784deb93cf06ff05e5f1b9d0a79d6f44903c4bf Mon Sep 17 00:00:00 2001 From: liyinshuo Date: Thu, 22 Apr 2021 15:55:14 +0800 Subject: [PATCH 7/7] Tiny Fix --- mmedit/models/backbones/sr_backbones/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmedit/models/backbones/sr_backbones/__init__.py b/mmedit/models/backbones/sr_backbones/__init__.py index 8a15c5e72f..d96a81a9ca 100644 --- a/mmedit/models/backbones/sr_backbones/__init__.py +++ b/mmedit/models/backbones/sr_backbones/__init__.py @@ -1,8 +1,8 @@ from .basicvsr_net import BasicVSRNet from .edsr import EDSR from .edvr_net import EDVRNet -from .rdn import RDN from .iconvsr import IconVSR +from .rdn import RDN from .rrdb_net import RRDBNet from .sr_resnet import MSRResNet from .srcnn import SRCNN @@ -10,5 +10,5 @@ __all__ = [ 'MSRResNet', 'RRDBNet', 'EDSR', 'EDVRNet', 'TOFlow', 'SRCNN', - 'BasicVSRNet', 'RDN', 'IconVSR' + 'BasicVSRNet', 'IconVSR', 'RDN' ]