[Feature] Add Cross-Scale Feature Integration (#312)
* [Feature] Add Cross-Scale Feature Integration

* rename

* rename

Co-authored-by: liyinshuo <[email protected]>
Yshuo-Li and liyinshuo authored May 17, 2021
1 parent 2c65d5f commit 0f25981
Showing 2 changed files with 137 additions and 4 deletions.
121 changes: 118 additions & 3 deletions mmedit/models/backbones/sr_backbones/ttsr_net.py
@@ -1,14 +1,17 @@
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer

from mmedit.models.common import ResidualBlockNoBN, make_layer

# Use partial to specify some default arguments
-_norm_conv_layer = partial(
+_conv3x3_layer = partial(
build_conv_layer, dict(type='Conv2d'), kernel_size=3, padding=1)
_conv1x1_layer = partial(
build_conv_layer, dict(type='Conv2d'), kernel_size=1, padding=0)


class SFE(nn.Module):
@@ -28,15 +31,15 @@ def __init__(self, in_channels, mid_channels, num_blocks, res_scale):
super().__init__()

self.num_blocks = num_blocks
-        self.conv_first = _norm_conv_layer(in_channels, mid_channels)
+        self.conv_first = _conv3x3_layer(in_channels, mid_channels)

self.body = make_layer(
ResidualBlockNoBN,
num_blocks,
mid_channels=mid_channels,
res_scale=res_scale)

-        self.conv_last = _norm_conv_layer(mid_channels, mid_channels)
+        self.conv_last = _conv3x3_layer(mid_channels, mid_channels)

def forward(self, x):
"""Forward function.
@@ -53,3 +56,115 @@ def forward(self, x):
x = self.conv_last(x)
x = x + x1
return x


class CSFI2(nn.Module):
"""Cross-Scale Feature Integration between 1x and 2x features.
Cross-Scale Feature Integration in Texture Transformer Network for
Image Super-Resolution.
It is cross-scale feature integration between 1x and 2x features.
For example, `conv2to1` means conv layer from 2x feature to 1x
feature. Down-sampling is achieved by conv layer with stride=2,
and up-sampling is achieved by bicubic interpolate and conv layer.
Args:
mid_channels (int): Channel number of intermediate features
"""

def __init__(self, mid_channels):
super().__init__()
self.conv1to2 = _conv1x1_layer(mid_channels, mid_channels)
self.conv2to1 = _conv3x3_layer(mid_channels, mid_channels, stride=2)

self.conv_merge1 = _conv3x3_layer(mid_channels * 2, mid_channels)
self.conv_merge2 = _conv3x3_layer(mid_channels * 2, mid_channels)

def forward(self, x1, x2):
"""Forward function.
Args:
x1 (Tensor): Input tensor with shape (n, c, h, w).
x2 (Tensor): Input tensor with shape (n, c, 2h, 2w).
Returns:
x1 (Tensor): Output tensor with shape (n, c, h, w).
x2 (Tensor): Output tensor with shape (n, c, 2h, 2w).
"""

x12 = F.interpolate(
x1, scale_factor=2, mode='bicubic', align_corners=False)
x12 = F.relu(self.conv1to2(x12))
x21 = F.relu(self.conv2to1(x2))

x1 = F.relu(self.conv_merge1(torch.cat((x1, x21), dim=1)))
x2 = F.relu(self.conv_merge2(torch.cat((x2, x12), dim=1)))

return x1, x2


class CSFI3(nn.Module):
"""Cross-Scale Feature Integration between 1x, 2x, and 4x features.
Cross-Scale Feature Integration in Texture Transformer Network for
Image Super-Resolution.
It is cross-scale feature integration between 1x and 2x features.
For example, `conv2to1` means conv layer from 2x feature to 1x
feature. Down-sampling is achieved by conv layer with stride=2,
and up-sampling is achieved by bicubic interpolate and conv layer.
Args:
mid_channels (int): Channel number of intermediate features
"""

def __init__(self, mid_channels):
super().__init__()
self.conv1to2 = _conv1x1_layer(mid_channels, mid_channels)
self.conv1to4 = _conv1x1_layer(mid_channels, mid_channels)

self.conv2to1 = _conv3x3_layer(mid_channels, mid_channels, stride=2)
self.conv2to4 = _conv1x1_layer(mid_channels, mid_channels)

self.conv4to1_1 = _conv3x3_layer(mid_channels, mid_channels, stride=2)
self.conv4to1_2 = _conv3x3_layer(mid_channels, mid_channels, stride=2)
self.conv4to2 = _conv3x3_layer(mid_channels, mid_channels, stride=2)

self.conv_merge1 = _conv3x3_layer(mid_channels * 3, mid_channels)
self.conv_merge2 = _conv3x3_layer(mid_channels * 3, mid_channels)
self.conv_merge4 = _conv3x3_layer(mid_channels * 3, mid_channels)

def forward(self, x1, x2, x4):
"""Forward function.
Args:
x1 (Tensor): Input tensor with shape (n, c, h, w).
x2 (Tensor): Input tensor with shape (n, c, 2h, 2w).
x4 (Tensor): Input tensor with shape (n, c, 4h, 4w).
Returns:
x1 (Tensor): Output tensor with shape (n, c, h, w).
x2 (Tensor): Output tensor with shape (n, c, 2h, 2w).
x4 (Tensor): Output tensor with shape (n, c, 4h, 4w).
"""

x12 = F.interpolate(
x1, scale_factor=2, mode='bicubic', align_corners=False)
x12 = F.relu(self.conv1to2(x12))
x14 = F.interpolate(
x1, scale_factor=4, mode='bicubic', align_corners=False)
x14 = F.relu(self.conv1to4(x14))

x21 = F.relu(self.conv2to1(x2))
x24 = F.interpolate(
x2, scale_factor=2, mode='bicubic', align_corners=False)
x24 = F.relu(self.conv2to4(x24))

x41 = F.relu(self.conv4to1_1(x4))
x41 = F.relu(self.conv4to1_2(x41))
x42 = F.relu(self.conv4to2(x4))

x1 = F.relu(self.conv_merge1(torch.cat((x1, x21, x41), dim=1)))
x2 = F.relu(self.conv_merge2(torch.cat((x2, x12, x42), dim=1)))
x4 = F.relu(self.conv_merge4(torch.cat((x4, x14, x24), dim=1)))

return x1, x2, x4
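
A note on the two partial helpers at the top of this file: with mmcv's build_conv_layer, a config of dict(type='Conv2d') builds a plain nn.Conv2d, so _conv3x3_layer and _conv1x1_layer are effectively shorthand for 3x3 and 1x1 convolutions. A minimal standalone sketch of that equivalence and of the bicubic up-sampling step (assuming standard mmcv/PyTorch behavior; this snippet is illustrative, not part of the commit):

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer

# _conv3x3_layer(c_in, c_out) expands to roughly this call, which for
# type='Conv2d' returns an ordinary nn.Conv2d(c_in, c_out, 3, padding=1).
conv3x3 = build_conv_layer(dict(type='Conv2d'), 16, 16, kernel_size=3, padding=1)
assert isinstance(conv3x3, nn.Conv2d)

# Up-sampling in CSFI: bicubic interpolation doubles (h, w), then a conv
# layer (1x1 here) mixes channels, as in conv1to2.
x1 = torch.rand(2, 16, 24, 24)  # 1x feature with shape (n, c, h, w)
x12 = F.interpolate(x1, scale_factor=2, mode='bicubic', align_corners=False)
assert x12.shape == (2, 16, 48, 48)  # now (n, c, 2h, 2w)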
20 changes: 19 additions & 1 deletion tests/test_ttsr.py
@@ -1,6 +1,6 @@
import torch

-from mmedit.models.backbones.sr_backbones.ttsr_net import SFE
+from mmedit.models.backbones.sr_backbones.ttsr_net import CSFI2, CSFI3, SFE


def test_sfe():
@@ -10,5 +10,23 @@ def test_sfe():
assert outputs.shape == (2, 64, 48, 48)


def test_csfi():
inputs1 = torch.rand(2, 16, 24, 24)
inputs2 = torch.rand(2, 16, 48, 48)
inputs3 = torch.rand(2, 16, 96, 96)

csfi2 = CSFI2(mid_channels=16)
out1, out2 = csfi2(inputs1, inputs2)
assert out1.shape == (2, 16, 24, 24)
assert out2.shape == (2, 16, 48, 48)

csfi3 = CSFI3(mid_channels=16)
out1, out2, out3 = csfi3(inputs1, inputs2, inputs3)
assert out1.shape == (2, 16, 24, 24)
assert out2.shape == (2, 16, 48, 48)
assert out3.shape == (2, 16, 96, 96)


if __name__ == '__main__':
test_sfe()
test_csfi()
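
One detail the test exercises implicitly: the 4x-to-1x path in CSFI3 chains two stride-2 convs (conv4to1_1, then conv4to1_2), because a single stride-2 conv only halves the spatial size. A quick standalone check of that arithmetic (illustrative sketch, not part of the test file):

import torch
import torch.nn as nn

# A 3x3 conv with stride=2 and padding=1 maps (n, c, 2h, 2w) -> (n, c, h, w),
# so two of them take a 4x feature down to the 1x scale.
down1 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
down2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
x4 = torch.rand(2, 16, 96, 96)  # 4x feature
assert down2(down1(x4)).shape == (2, 16, 24, 24)  # back at the 1x scale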
