# generator.py
import torch
import torch.nn as nn

from utils import *  # provides ConvLayer


class SARB(nn.Module):
    """Residual block with an attention branch:
    out = x + SEDN(conv2(conv1(x)))."""

    def __init__(self, c_in, c_out, relu_type='relu', norm_type='in', hg_depth=2, att_name='spar'):
        super(SARB, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.norm_type = norm_type
        self.relu_type = relu_type
        self.hg_depth = hg_depth
        kwargs = {'norm_type': norm_type, 'relu_type': relu_type}

        self.conv1 = ConvLayer(c_in, c_out, 3, **kwargs)
        self.conv2 = ConvLayer(c_out, c_out, 3, norm_type=norm_type, relu_type='none')
        # Hourglass-style attention applied to the residual branch.
        self.att_func = SEDN(self.hg_depth, c_out, 1)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = identity + self.att_func(out)
        return out
class SEDN(nn.Module):
    """Recursive encoder-decoder (hourglass) that predicts a sigmoid attention
    map and multiplies it with the input feature map."""

    def __init__(self, depth, c_in, c_out, c_mid=64, norm_type='in', relu_type='relu'):
        super(SEDN, self).__init__()
        self.depth = depth
        self.c_in = c_in
        self.c_mid = c_mid
        self.c_out = c_out
        self.kwargs = {'norm_type': norm_type, 'relu_type': relu_type}
        if self.depth:
            self._generate_network(self.depth)
        self.out_block = nn.Sequential(
            ConvLayer(self.c_mid, self.c_out, norm_type='none', relu_type='none'),
            nn.Sigmoid()
        )

    def _generate_network(self, level):
        # The outermost level maps c_in -> c_mid; inner levels stay at c_mid.
        if level == self.depth:
            c1, c2 = self.c_in, self.c_mid
        else:
            c1, c2 = self.c_mid, self.c_mid

        self.add_module('b1_' + str(level), ConvLayer(c1, c2, **self.kwargs))                # skip branch
        self.add_module('b2_' + str(level), ConvLayer(c1, c2, scale='down', **self.kwargs))  # encoder step
        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level), ConvLayer(self.c_mid, self.c_mid, **self.kwargs))
        # Decoder step; b3 exists at every level (see _forward).
        self.add_module('b3_' + str(level), ConvLayer(self.c_mid, self.c_mid, scale='up', **self.kwargs))

    def _forward(self, level, in_x):
        up1 = self._modules['b1_' + str(level)](in_x)
        low1 = self._modules['b2_' + str(level)](in_x)
        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = self._modules['b2_plus_' + str(level)](low1)
        up2 = self._modules['b3_' + str(level)](low2)
        # Align spatial sizes before the skip addition (handles odd input sizes).
        if up1.shape[2:] != up2.shape[2:]:
            up2 = nn.functional.interpolate(up2, up1.shape[2:])
        return up1 + up2

    def forward(self, x, pmask=None):
        if self.depth == 0:
            return x
        input_x = x
        x = self._forward(self.depth, x)
        self.att_map = self.out_block(x)
        x = input_x * self.att_map
        return x
class Generator(nn.Module):
    """Generator network."""

    def __init__(self, conv_dim=64, c_dim=5, repeat_num=1):  # repeat_num is unused here
        super(Generator, self).__init__()
        layers = []
        layers.append(nn.Conv2d(3 + c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.InstanceNorm2d(conv_dim, affine=True, track_running_stats=True))
        layers.append(nn.ReLU(inplace=True))

        # Down-sampling layers.
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2

        # Bottleneck: attention residual block at the lowest resolution.
        layers.append(SARB(c_in=curr_dim, c_out=curr_dim))

        # Up-sampling layers.
        for i in range(2):
            layers.append(nn.ConvTranspose2d(curr_dim, curr_dim // 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim // 2, affine=True, track_running_stats=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2

        layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        layers.append(nn.Tanh())
        self.att_func = SEDN(2, 3, 1)  # defined but not used in forward()
        self.main = nn.Sequential(*layers)

    def forward(self, x, c):
        xo = x
        # Broadcast the conditioning vector c to a spatial map and concatenate
        # it with the input image along the channel dimension.
        c = c.view(c.size(0), c.size(1), 1, 1)
        c = c.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat([x, c], dim=1)
        y = self.main(x)
        # Global residual connection: the network predicts a correction to x.
        x = y + xo
        return x
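

# --- Usage sketch (not part of the original pipeline) ---
# A minimal, hypothetical smoke test of the Generator above. It assumes that
# utils.ConvLayer preserves spatial size for the calls made in SARB/SEDN and
# that the input resolution is divisible by 4 (two stride-2 down-samplings).
if __name__ == '__main__':
    G = Generator(conv_dim=64, c_dim=5)
    x = torch.randn(1, 3, 128, 128)   # dummy image batch
    c = torch.randn(1, 5)             # dummy conditioning vector (c_dim = 5)
    out = G(x, c)
    print(out.shape)                  # expected: torch.Size([1, 3, 128, 128])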