Skip to content

Commit

Permalink
[Feature] Add Merge-Features.
Browse files Browse the repository at this point in the history
  • Loading branch information
liyinshuo committed May 17, 2021
1 parent 0f25981 commit bbdb234
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
42 changes: 42 additions & 0 deletions mmedit/models/backbones/sr_backbones/ttsr_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,45 @@ def forward(self, x1, x2, x4):
x4 = F.relu(self.conv_merge4(torch.cat((x4, x14, x24), dim=1)))

return x1, x2, x4


class MergeFeatures(nn.Module):
"""Merge Features. Merge 1x, 2x, and 4x features.
Final module of Texture Transformer Network for Image Super-Resolution.
"""

def __init__(self, mid_channels, out_channels):
super().__init__()
self.conv1to4 = _conv1x1_layer(mid_channels, mid_channels)
self.conv2to4 = _conv1x1_layer(mid_channels, mid_channels)
self.conv_merge = _conv3x3_layer(mid_channels * 3, mid_channels)
self.conv_last1 = _conv3x3_layer(mid_channels, mid_channels // 2)
self.conv_last2 = _conv1x1_layer(mid_channels // 2, out_channels)

def forward(self, x1, x2, x4):
"""Forward function.
Args:
x1 (Tensor): Input tensor with shape (n, c, h, w).
x2 (Tensor): Input tensor with shape (n, c, 2h, 2w).
x4 (Tensor): Input tensor with shape (n, c, 4h, 4w).
Returns:
x (Tensor): Output tensor with shape (n, c_out, 4h, 4w).
"""

x14 = F.interpolate(
x1, scale_factor=4, mode='bicubic', align_corners=False)
x14 = F.relu(self.conv1to4(x14))
x24 = F.interpolate(
x2, scale_factor=2, mode='bicubic', align_corners=False)
x24 = F.relu(self.conv2to4(x24))

x = F.relu(self.conv_merge(torch.cat((x4, x14, x24), dim=1)))
x = self.conv_last1(x)
x = self.conv_last2(x)
x = torch.clamp(x, -1, 1)

return x
20 changes: 13 additions & 7 deletions tests/test_ttsr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

from mmedit.models.backbones.sr_backbones.ttsr_net import CSFI2, CSFI3, SFE
from mmedit.models.backbones.sr_backbones.ttsr_net import (CSFI2, CSFI3, SFE,
MergeFeatures)


def test_sfe():
Expand All @@ -13,20 +14,25 @@ def test_sfe():
def test_csfi():
inputs1 = torch.rand(2, 16, 24, 24)
inputs2 = torch.rand(2, 16, 48, 48)
inputs3 = torch.rand(2, 16, 96, 96)
inputs4 = torch.rand(2, 16, 96, 96)

csfi2 = CSFI2(mid_channels=16)
out1, out2 = csfi2(inputs1, inputs2)
assert out1.shape == (2, 16, 24, 24)
assert out2.shape == (2, 16, 48, 48)

csfi3 = CSFI3(mid_channels=16)
out1, out2, out3 = csfi3(inputs1, inputs2, inputs3)
out1, out2, out4 = csfi3(inputs1, inputs2, inputs4)
assert out1.shape == (2, 16, 24, 24)
assert out2.shape == (2, 16, 48, 48)
assert out3.shape == (2, 16, 96, 96)
assert out4.shape == (2, 16, 96, 96)


if __name__ == '__main__':
test_sfe()
test_csfi()
def test_merge_features():
inputs1 = torch.rand(2, 16, 24, 24)
inputs2 = torch.rand(2, 16, 48, 48)
inputs4 = torch.rand(2, 16, 96, 96)

merge_features = MergeFeatures(mid_channels=16, out_channels=3)
out = merge_features(inputs1, inputs2, inputs4)
assert out.shape == (2, 3, 96, 96)

0 comments on commit bbdb234

Please sign in to comment.