-
Notifications
You must be signed in to change notification settings - Fork 1
/
attention_module.py
67 lines (61 loc) · 3.09 KB
/
attention_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch.nn as nn
from basic_layers import ResidualBlock
class AttentionModule(nn.Module):
    """Attention module: a trunk branch modulated by a bottom-up/top-down soft mask.

    The input first passes through one residual block (the only stage that
    changes the channel count, ``in_channels -> out_channels``). The result
    feeds two branches:

    * Trunk branch: two residual blocks, producing ``T``.
    * Mask branch, with three downsampling stages and then upsampling:
      - Downsampling: each stage is a max-pool (kernel 3, stride 2,
        padding 1) followed by residual blocks. The first two stages also
        tap off a residual-block skip connection.
      - Upsampling: bilinear upsampling back through ``size3``, ``size2``
        and ``size1``. Each of the first two steps is added to the matching
        skip connection.
      - Mask head: BN -> ReLU -> 1x1 conv (twice) and a sigmoid, producing
        the mask ``M`` with values in (0, 1).

    The branches are merged as ``(1 + M) * T`` and passed through a final
    residual block. This matches the "attention residual learning" form of
    Residual Attention Network (Wang et al., 2017), presumably the design
    source; confirm.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count after the first residual block. Every
            later stage is built for this channel count, assuming
            ``ResidualBlock(c_in, c_out)`` maps ``c_in`` to ``c_out``
            channels (TODO confirm in ``basic_layers``).
        size1: Output size of the last upsampling step. It must equal the
            trunk output's spatial size so ``(1 + M) * T`` lines up.
        size2: Output size of the middle upsampling step. It must equal the
            spatial size of the first skip connection it is added to.
        size3: Output size of the first upsampling step. It must equal the
            spatial size of the second skip connection it is added to.

    Note:
        The interpolation sizes are fixed at construction time. The module
        therefore only works for the input resolution those sizes were
        computed for.
    """

    def __init__(self, in_channels, out_channels, size1, size2, size3):
        super().__init__()
        # Only this block converts in_channels -> out_channels. Every later
        # stage receives an out_channels tensor, so it is built as
        # out -> out; building it as in -> out fails whenever
        # in_channels != out_channels.
        # Submodules are registered in the original order so state_dict keys
        # and parameter order are unchanged.
        self.first_residual_blocks = ResidualBlock(in_channels, out_channels)
        channels = out_channels

        # Trunk branch.
        self.trunk_branches = nn.Sequential(
            ResidualBlock(channels, channels),
            ResidualBlock(channels, channels),
        )

        # Mask branch, downsampling path (each pool: kernel 3, stride 2, pad 1).
        self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax1_blocks = ResidualBlock(channels, channels)
        self.skip1_connection_residual_block = ResidualBlock(channels, channels)

        self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax2_blocks = ResidualBlock(channels, channels)
        self.skip2_connection_residual_block = ResidualBlock(channels, channels)

        self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.softmax3_blocks = nn.Sequential(
            ResidualBlock(channels, channels),
            ResidualBlock(channels, channels),
        )

        # Mask branch, upsampling path back to the fixed target sizes.
        self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
        self.softmax4_blocks = ResidualBlock(channels, channels)
        self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
        self.softmax5_blocks = ResidualBlock(channels, channels)
        self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)

        # Mask head: two BN-ReLU-1x1 conv units, then a sigmoid giving M in (0, 1).
        self.softmax6_blocks = nn.Sequential(
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=1, stride=1, bias=False),
            nn.Sigmoid(),
        )

        self.last_blocks = ResidualBlock(channels, channels)

    def forward(self, x):
        """Apply the attention module to ``x``.

        Args:
            x: Input with ``in_channels`` channels. It is presumably
                ``(batch, in_channels, H, W)``, since the mask head uses
                ``BatchNorm2d``; TODO confirm.

        Returns:
            ``last_blocks((1 + M) * T)``, where ``T`` is the trunk output and
            ``M`` the sigmoid mask.
        """
        x = self.first_residual_blocks(x)
        out_trunk = self.trunk_branches(x)

        # Mask branch, downsampling; skips are tapped before each next pool.
        out_mpool1 = self.mpool1(x)
        out_softmax1 = self.softmax1_blocks(out_mpool1)
        out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)

        out_mpool2 = self.mpool2(out_softmax1)
        out_softmax2 = self.softmax2_blocks(out_mpool2)
        out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)

        out_mpool3 = self.mpool3(out_softmax2)
        out_softmax3 = self.softmax3_blocks(out_mpool3)

        # Mask branch, upsampling with additive skip connections. Each fixed
        # interpolation size must match the spatial size of the tensor it is
        # added to, or the addition fails.
        out_interp3 = self.interpolation3(out_softmax3)
        out = out_interp3 + out_skip2_connection
        out_softmax4 = self.softmax4_blocks(out)

        out_interp2 = self.interpolation2(out_softmax4)
        out = out_interp2 + out_skip1_connection
        out_softmax5 = self.softmax5_blocks(out)

        out_interp1 = self.interpolation1(out_softmax5)
        out_softmax6 = self.softmax6_blocks(out_interp1)

        # M is in (0, 1), so (1 + M) scales trunk features by a factor in
        # (1, 2): the mask re-weights features but cannot zero them out.
        out = (1 + out_softmax6) * out_trunk
        return self.last_blocks(out)