import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from odegcn import ODEG


class Chomp1d(nn.Module):
    """
    Removes the extra timesteps added by padding, keeping the temporal convolution causal.
    """
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        # x: (B, F, N, T + chomp_size) -> (B, F, N, T)
        return x[:, :, :, :-self.chomp_size].contiguous()
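
# Illustrative note (added for clarity, not in the original code): with
# kernel_size=2, dilation d, and padding (0, d) applied to both ends of the
# time axis, the convolution maps a length-T sequence to length T + d;
# chomping the trailing d steps restores length T and leaves every output
# step t depending only on input steps <= t.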


class TemporalConvNet(nn.Module):
    """
    Dilated temporal convolution network (TCN).
    """
    def __init__(self, graph, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        """
        Args:
            num_inputs : number of channels of the input features
            num_channels : channel widths of the feature-transform layers; the last entry is the output width
            kernel_size : a 1d temporal convolution, so the effective kernel is (1, kernel_size)
        """
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size
            # Build each layer with local variables: assigning to self.conv
            # inside the loop would leave the attribute pointing at the last
            # layer only and register its parameters a second time.
            conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                             dilation=(1, dilation_size), padding=(0, padding))
            conv.weight.data.normal_(0, 0.01)
            layers += [nn.Sequential(conv, Chomp1d(padding), nn.ReLU(), nn.Dropout(dropout))]
        self.network = nn.Sequential(*layers)
        # 1x1 convolution that matches the residual branch to the output width
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """
        Residual structure, like ResNet.
        Args:
            x : input data of shape (B, N, T, F)
        Return:
            output data of shape (B, N, T, num_channels[-1])
        """
        # permute shape to (B, F, N, T)
        y = x.permute(0, 3, 1, 2)
        # The residual term needs its own parentheses: the original
        # `F.relu(self.network(y) + self.downsample(y) if self.downsample else y)`
        # skipped the network entirely whenever no downsample layer existed.
        residual = self.downsample(y) if self.downsample is not None else y
        y = F.relu(self.network(y) + residual)
        y = y.permute(0, 2, 3, 1)
        return y
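
# Receptive-field sketch (illustrative): with kernel_size=2 and dilations
# 1, 2, ..., 2^(L-1), layer i adds 2^i steps of history, so the default
# three-layer stack covers a window of 1 + 1 + 2 + 4 = 8 timesteps. Shape
# check (arbitrary values; `graph` is unused by the class, so None suffices):
#   tcn = TemporalConvNet(None, num_inputs=2, num_channels=[64, 32, 64])
#   y = tcn(torch.randn(2, 20, 12, 2))   # -> (2, 20, 12, 64)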


class GCN(nn.Module):
    def __init__(self, A_hat, in_channels, out_channels):
        super(GCN, self).__init__()
        self.A_hat = A_hat
        self.theta = nn.Parameter(torch.FloatTensor(in_channels, out_channels))
        self.reset()

    def reset(self):
        stdv = 1. / math.sqrt(self.theta.shape[1])
        self.theta.data.uniform_(-stdv, stdv)

    def forward(self, X):
        # aggregate neighbour features along the node dimension ...
        y = torch.einsum('ij,kjlm->kilm', self.A_hat, X)
        # ... then mix the feature channels with theta
        return F.relu(torch.einsum('kjlm,mn->kjln', y, self.theta))
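
# The two einsums implement relu(A_hat @ X @ theta), batched over the batch
# and time axes. A quick check (illustrative): with A_hat = torch.eye(N) the
# layer degenerates to a per-node linear map followed by ReLU:
#   gcn = GCN(torch.eye(20), in_channels=64, out_channels=64)
#   out = gcn(torch.randn(2, 20, 12, 64))   # -> (2, 20, 12, 64)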


class STGCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_nodes, A_hat, type):
        """
        Args:
            in_channels: number of input features at each node in each time step
            out_channels: list of feature channels in the time block; the last entry is the output width
            num_nodes: number of nodes in the graph
            A_hat: the normalized adjacency matrix
            type: which graph branch the block serves, 'sp' (spatial) or 'se' (semantic)
        """
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(A_hat, num_inputs=in_channels,
                                         num_channels=out_channels)
        self.temporal2 = TemporalConvNet(A_hat, num_inputs=out_channels[-1],
                                         num_channels=out_channels)
        self.odeg1 = ODEG(num_nodes, type, out_channels[-1], 12, A_hat, time=6)
        # the input is laid out as (B, N, T, C), so BatchNorm2d normalizes
        # over the node dimension here
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        """
        Args:
            X: input data of shape (batch_size, num_nodes, num_timesteps, num_features)
        Return:
            output data of shape (batch_size, num_nodes, num_timesteps, out_channels[-1])
        """
        X = self.temporal1(X)
        X = self.odeg1(X)
        X = self.temporal2(X)
        return self.batch_norm(X)
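
# Shape trace through one block (illustrative), for a 12-step input and
# out_channels=[64, 32, 64]:
#   (B, N, 12, C_in) --temporal1--> (B, N, 12, 64)
#                    --odeg1------> (B, N, 12, 64)
#                    --temporal2--> (B, N, 12, 64)
#                    --batch_norm-> (B, N, 12, 64)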


class ODEGCN(nn.Module):
    """ The overall network framework. """
    def __init__(self, num_nodes, num_features, num_timesteps_input,
                 num_timesteps_output, A_sp_hat, A_se_hat):
        """
        Args:
            num_nodes : number of nodes in the graph
            num_features : number of features at each node in each time step
            num_timesteps_input : number of past time steps fed into the network
            num_timesteps_output : desired number of future time steps output by the network
            A_sp_hat : normalized spatial adjacency matrix
            A_se_hat : normalized semantic adjacency matrix
        """
        super(ODEGCN, self).__init__()
        # spatial adjacency graph branch
        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat, type='sp'),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=A_sp_hat, type='sp')) for _ in range(3)
            ])
        # semantic (DTW) graph branch
        self.se_blocks = nn.ModuleList([nn.Sequential(
            STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                       num_nodes=num_nodes, A_hat=A_se_hat, type='se'),
            STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                       num_nodes=num_nodes, A_hat=A_se_hat, type='se')) for _ in range(3)
            ])
        # attention module; the hard-coded sizes assume 12 input timesteps,
        # 64 output channels per block, 6 stacked branches, and 12 predicted
        # timesteps: dim_in = 12 * 64 * 6, dim_v = 12
        self.pred = MultiHeadSelfAttention(12 * 64 * 6, 6 * 64, 12)

    def forward(self, x):
        """
        Args:
            x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
        Returns:
            prediction for the future, of shape (batch_size, num_nodes, num_timesteps_output)
        """
        outs = []
        # spatial graph branch
        for blk in self.sp_blocks:
            outs.append(blk(x))
        # semantic graph branch
        for blk in self.se_blocks:
            outs.append(blk(x))
        # stack the six branch outputs, then flatten the time, channel and
        # branch axes into one feature vector per node
        outs = torch.stack(outs, dim=-1)
        b, n, t, c, s = outs.shape
        x = outs.reshape(b, n, t * c * s)
        return self.pred(x)
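
# Dimension check for the attention input (illustrative): the six branch
# outputs of shape (B, N, 12, 64) stack and flatten to (B, N, 12 * 64 * 6)
# = (B, N, 4608), matching the dim_in passed to MultiHeadSelfAttention in
# __init__; its dim_v of 12 yields one value per predicted timestep.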


class MultiHeadSelfAttention(nn.Module):
    """
    Args:
        dim_in : input dimension
        dim_k : key and query dimension
        dim_v : value dimension
        num_heads : number of heads; for each head, dim_* = dim_* // num_heads
    """
    def __init__(self, dim_in, dim_k, dim_v, num_heads=12):
        super(MultiHeadSelfAttention, self).__init__()
        assert dim_k % num_heads == 0 and dim_v % num_heads == 0, \
            "dim_k and dim_v must be multiples of num_heads"
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.linear_q = nn.Linear(dim_in, dim_k, bias=True)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=True)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=True)
        # scaled dot-product factor: 1 / sqrt(d_k per head)
        self._norm_fact = 1 / math.sqrt(dim_k // num_heads)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in
        nh = self.num_heads
        dk = self.dim_k // nh  # dim_k of each head
        dv = self.dim_v // nh  # dim_v of each head
        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)  # (batch, nh, n, dv)
        dist = torch.matmul(q, k.transpose(2, 3)) * self._norm_fact  # (batch, nh, n, n)
        dist = torch.softmax(dist, dim=-1)  # (batch, nh, n, n)
        att = torch.matmul(dist, v)  # (batch, nh, n, dv)
        att = att.transpose(1, 2).reshape(batch, n, self.dim_v)  # (batch, n, dim_v)
        return att
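

# A minimal smoke test (an illustrative addition, not part of the original
# module). It exercises the self-contained pieces -- TemporalConvNet, GCN
# and MultiHeadSelfAttention -- with random tensors and the dimensions that
# ODEGCN uses above; ODEGCN itself is not built here because it depends on
# ODEG from odegcn.
if __name__ == '__main__':
    B, N, T, F_in = 2, 20, 12, 2

    # temporal convolution: (B, N, T, F_in) -> (B, N, T, 64)
    tcn = TemporalConvNet(None, num_inputs=F_in, num_channels=[64, 32, 64])
    y = tcn(torch.randn(B, N, T, F_in))
    assert y.shape == (B, N, T, 64)

    # graph convolution with an identity adjacency: shape-preserving here
    gcn = GCN(torch.eye(N), in_channels=64, out_channels=64)
    z = gcn(y)
    assert z.shape == (B, N, T, 64)

    # attention head sized as in ODEGCN: 12 * 64 * 6 inputs, 12 outputs
    attn = MultiHeadSelfAttention(12 * 64 * 6, 6 * 64, 12)
    pred = attn(torch.randn(B, N, 12 * 64 * 6))
    assert pred.shape == (B, N, 12)
    print('smoke test passed')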