-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathcore.py
40 lines (34 loc) · 1.83 KB
/
core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch import nn
from torch.nn import functional as F
from conv_lstm import Conv2dLSTMCell
class InferenceCore(nn.Module):
def __init__(self):
super(InferenceCore, self).__init__()
self.downsample_x = nn.Conv2d(3, 3, kernel_size=4, stride=4, padding=0, bias=False)
self.upsample_v = nn.ConvTranspose2d(7, 7, kernel_size=16, stride=16, padding=0, bias=False)
self.upsample_r = nn.ConvTranspose2d(256, 256, kernel_size=16, stride=16, padding=0, bias=False)
self.downsample_u = nn.Conv2d(128, 128, kernel_size=4, stride=4, padding=0, bias=False)
self.core = Conv2dLSTMCell(3+7+256+2*128, 128, kernel_size=5, stride=1, padding=2)
def forward(self, x, v, r, c_e, h_e, h_g, u):
x = self.downsample_x(x)
v = self.upsample_v(v.view(-1, 7, 1, 1))
if r.size(2)!=h_e.size(2):
r = self.upsample_r(r)
u = self.downsample_u(u)
c_e, h_e = self.core(torch.cat((x, v, r, h_g, u), dim=1), (c_e, h_e))
return c_e, h_e
class GenerationCore(nn.Module):
def __init__(self):
super(GenerationCore, self).__init__()
self.upsample_v = nn.ConvTranspose2d(7, 7, kernel_size=16, stride=16, padding=0, bias=False)
self.upsample_r = nn.ConvTranspose2d(256, 256, kernel_size=16, stride=16, padding=0, bias=False)
self.core = Conv2dLSTMCell(7+256+3, 128, kernel_size=5, stride=1, padding=2)
self.upsample_h = nn.ConvTranspose2d(128, 128, kernel_size=4, stride=4, padding=0, bias=False)
def forward(self, v, r, c_g, h_g, u, z):
v = self.upsample_v(v.view(-1, 7, 1, 1))
if r.size(2)!=h_g.size(2):
r = self.upsample_r(r)
c_g, h_g = self.core(torch.cat((v, r, z), dim=1), (c_g, h_g))
u = self.upsample_h(h_g) + u
return c_g, h_g, u